Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement GlobalPhaseGate #4697

Merged
merged 23 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@
generalized_amplitude_damp,
GeneralizedAmplitudeDampingChannel,
givens,
GlobalPhaseGate,
GlobalPhaseOperation,
global_phase_operation,
H,
HPowGate,
I,
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ def default_namer(label_entity):
diagram.write(0, i, name)
first_annotation_row = max(label_map.values(), default=0) + 1

if any(isinstance(op.untagged, cirq.GlobalPhaseOperation) for op in self.all_operations()):
if any(isinstance(op.gate, cirq.GlobalPhaseGate) for op in self.all_operations()):
diagram.write(0, max(label_map.values(), default=0) + 1, 'global phase:')
first_annotation_row += 1

Expand Down Expand Up @@ -2359,7 +2359,7 @@ def _get_moment_annotations(
if op.qubits:
continue
op = op.untagged
if isinstance(op, ops.GlobalPhaseOperation):
if isinstance(op.gate, ops.GlobalPhaseGate):
continue
if isinstance(op, CircuitOperation):
for m in op.circuit:
Expand Down Expand Up @@ -2491,8 +2491,8 @@ def _draw_moment_in_diagram(


def _get_global_phase_and_tags_for_op(op: 'cirq.Operation') -> Tuple[Optional[complex], List[Any]]:
if isinstance(op.untagged, ops.GlobalPhaseOperation):
return complex(op.untagged.coefficient), list(op.tags)
if isinstance(op.gate, ops.GlobalPhaseGate):
return complex(op.gate.coefficient), list(op.tags)
elif isinstance(op.untagged, CircuitOperation):
op_phase, op_tags = _get_global_phase_and_tags_for_ops(op.untagged.circuit.all_operations())
return op_phase, list(op.tags) + op_tags
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,11 @@ def test_string_format():
assert str(op0) == f"[ ]"

fc0_global_phase_inner = cirq.FrozenCircuit(
cirq.GlobalPhaseOperation(1j), cirq.GlobalPhaseOperation(1j)
cirq.global_phase_operation(1j), cirq.global_phase_operation(1j)
)
op0_global_phase_inner = cirq.CircuitOperation(fc0_global_phase_inner)
fc0_global_phase_outer = cirq.FrozenCircuit(
op0_global_phase_inner, cirq.GlobalPhaseOperation(1j)
op0_global_phase_inner, cirq.global_phase_operation(1j)
)
op0_global_phase_outer = cirq.CircuitOperation(fc0_global_phase_outer)
assert (
Expand Down
8 changes: 5 additions & 3 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2375,7 +2375,7 @@ def test_diagram_wgate_none_precision(circuit_cls):
@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_diagram_global_phase(circuit_cls):
qa = cirq.NamedQubit('a')
global_phase = cirq.GlobalPhaseOperation(coefficient=1j)
global_phase = cirq.global_phase_operation(coefficient=1j)
c = circuit_cls([global_phase])
cirq.testing.assert_has_diagram(
c, "\n\nglobal phase: 0.5pi", use_unicode_characters=False, precision=2
Expand Down Expand Up @@ -2408,7 +2408,9 @@ def test_diagram_global_phase(circuit_cls):

c = circuit_cls(
cirq.X(cirq.LineQubit(2)),
cirq.CircuitOperation(circuit_cls(cirq.GlobalPhaseOperation(-1).with_tags("tag")).freeze()),
cirq.CircuitOperation(
circuit_cls(cirq.global_phase_operation(-1).with_tags("tag")).freeze()
),
)
cirq.testing.assert_has_diagram(
c,
Expand Down Expand Up @@ -4938,7 +4940,7 @@ def _circuit_diagram_info_(self, args) -> str:
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
CustomOperationAnnotation("a"),
cirq.GlobalPhaseOperation(1j),
cirq.global_phase_operation(1j),
),
),
"""
Expand Down
10 changes: 5 additions & 5 deletions cirq-core/cirq/interop/quirk/cells/scalar_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@


def generate_all_scalar_cell_makers() -> Iterator[CellMaker]:
yield _scalar("NeGate", ops.GlobalPhaseOperation(-1))
yield _scalar("i", ops.GlobalPhaseOperation(1j))
yield _scalar("-i", ops.GlobalPhaseOperation(-1j))
yield _scalar("√i", ops.GlobalPhaseOperation(1j ** 0.5))
yield _scalar("√-i", ops.GlobalPhaseOperation((-1j) ** 0.5))
yield _scalar("NeGate", ops.global_phase_operation(-1))
yield _scalar("i", ops.global_phase_operation(1j))
yield _scalar("-i", ops.global_phase_operation(-1j))
yield _scalar("√i", ops.global_phase_operation(1j ** 0.5))
yield _scalar("√-i", ops.global_phase_operation((-1j) ** 0.5))


def _scalar(identifier: str, operation: 'cirq.Operation') -> CellMaker:
Expand Down
12 changes: 7 additions & 5 deletions cirq-core/cirq/interop/quirk/cells/scalar_cells_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@ def test_scalar_operations():
assert_url_to_circuit_returns('{"cols":[["…"]]}', cirq.Circuit())

assert_url_to_circuit_returns(
'{"cols":[["NeGate"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(-1))
'{"cols":[["NeGate"]]}', cirq.Circuit(cirq.global_phase_operation(-1))
)

assert_url_to_circuit_returns('{"cols":[["i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j)))
assert_url_to_circuit_returns('{"cols":[["i"]]}', cirq.Circuit(cirq.global_phase_operation(1j)))

assert_url_to_circuit_returns('{"cols":[["-i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(-1j)))
assert_url_to_circuit_returns(
'{"cols":[["-i"]]}', cirq.Circuit(cirq.global_phase_operation(-1j))
)

assert_url_to_circuit_returns(
'{"cols":[["√i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j ** 0.5))
'{"cols":[["√i"]]}', cirq.Circuit(cirq.global_phase_operation(1j ** 0.5))
)

assert_url_to_circuit_returns(
'{"cols":[["√-i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j ** -0.5))
'{"cols":[["√-i"]]}', cirq.Circuit(cirq.global_phase_operation(1j ** -0.5))
)
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def _parallel_gate_op(gate, qubits):
'GateOperation': cirq.GateOperation,
'Gateset': cirq.Gateset,
'GeneralizedAmplitudeDampingChannel': cirq.GeneralizedAmplitudeDampingChannel,
'GlobalPhaseGate': cirq.GlobalPhaseGate,
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
'GridInteractionLayer': GridInteractionLayer,
'GridParallelXEBMetadata': GridParallelXEBMetadata,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def _decompose_(self, qubits):

a, b = qubits
return [
ops.GlobalPhaseOperation(self.global_phase),
ops.global_phase_operation(self.global_phase),
ops.MatrixGate(self.single_qubit_operations_before[0]).on(a),
ops.MatrixGate(self.single_qubit_operations_before[1]).on(b),
np.exp(1j * ops.X(a) * ops.X(b) * self.interaction_coefficients[0]),
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@
)

from cirq.ops.global_phase_op import (
GlobalPhaseGate,
GlobalPhaseOperation,
global_phase_operation,
)

from cirq.ops.kraus_channel import (
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _decompose_(self, qubits):
return NotImplemented
result = [PAULI_GATES[p].on(q) for p, q in zip(self.pauli_mask, qubits) if p]
if self.coefficient != 1:
result.append(global_phase_op.GlobalPhaseOperation(self.coefficient))
result.append(global_phase_op.global_phase_operation(self.coefficient))
return result

def _is_parameterized_(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
# diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence,
# we add global phase.
decomposed_circ: List[Any] = [
global_phase_op.GlobalPhaseOperation(np.exp(1j * hat_angles[0]))
global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))
]
for i, bit_flip in _gen_gray_code(n):
decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits))
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __repr__(self):

def __str__(self) -> str:
qubits = ', '.join(str(e) for e in self.qubits)
return f'{self.gate}({qubits})'
return f'{self.gate}({qubits})' if qubits else str(self.gate)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gate', 'qubits'])
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
g = item if isinstance(item, raw_types.Gate) else item.gate
assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'

if isinstance(g, global_phase_op.GlobalPhaseGate):
return self._accept_global_phase_op

daxfohl marked this conversation as resolved.
Show resolved Hide resolved
if g in self._instance_gate_families:
assert item in self._instance_gate_families[g], (
f"{item} instance matches {self._instance_gate_families[g]} but "
Expand Down Expand Up @@ -396,8 +399,6 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
lambda q: cast(circuit_operation.CircuitOperation, op).qubit_map.get(q, q)
)
return self.validate(op_circuit)
elif isinstance(op, global_phase_op.GlobalPhaseOperation):
return self._accept_global_phase_op
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
else:
return False

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_gate_family_eq():
(cirq.SingleQubitGate(), False),
(cirq.X ** 0.5, False),
(None, False),
(cirq.GlobalPhaseOperation(1j), False),
(cirq.global_phase_operation(1j), False),
],
),
(
Expand All @@ -144,7 +144,7 @@ def test_gate_family_eq():
(CustomX ** 3, True),
(CustomX ** sympy.Symbol('theta'), False),
(None, False),
(cirq.GlobalPhaseOperation(1j), False),
(cirq.global_phase_operation(1j), False),
],
),
(
Expand Down Expand Up @@ -255,7 +255,7 @@ def get_ops(use_circuit_op, use_global_phase):
)
yield [circuit_op, recursive_circuit_op]
if use_global_phase:
yield cirq.GlobalPhaseOperation(1j)
yield cirq.global_phase_operation(1j)

def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert all(op in gateset for op in cirq.flatten_to_ops(op_tree)) is result
Expand Down
65 changes: 51 additions & 14 deletions cirq-core/cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,69 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A no-qubit global phase operation."""
from typing import Any, Dict, Tuple, TYPE_CHECKING
from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING

import numpy as np

from cirq import value, protocols
from cirq.ops import raw_types
from cirq._compat import deprecated_class
from cirq.ops import gate_operation, raw_types

if TYPE_CHECKING:
import cirq


@value.value_equality(approximate=True)
class GlobalPhaseOperation(raw_types.Operation):
@deprecated_class(deadline='v0.16', fix='Use cirq.global_phase_operation')
class GlobalPhaseOperation(gate_operation.GateOperation):
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, coefficient: value.Scalar, atol: float = 1e-8) -> None:
if abs(1 - abs(coefficient)) > atol:
raise ValueError(f'Coefficient is not unitary: {coefficient!r}')
self.coefficient = coefficient

@property
def qubits(self) -> Tuple['cirq.Qid', ...]:
return ()
gate = GlobalPhaseGate(coefficient, atol)
super().__init__(gate, [])

def with_qubits(self, *new_qubits) -> 'GlobalPhaseOperation':
if new_qubits:
raise ValueError(f'{self!r} applies to 0 qubits but new_qubits={new_qubits!r}.')
return self

@property
def coefficient(self) -> value.Scalar:
return self.gate.coefficient # type: ignore

@coefficient.setter
def coefficient(self, coefficient: value.Scalar):
# coverage: ignore
self.gate._coefficient = coefficient # type: ignore

def __str__(self) -> str:
return str(self.coefficient)

def __repr__(self) -> str:
return f'cirq.GlobalPhaseOperation({self.coefficient!r})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['coefficient'])


@value.value_equality(approximate=True)
class GlobalPhaseGate(raw_types.Gate):
def __init__(self, coefficient: value.Scalar, atol: float = 1e-8) -> None:
if abs(1 - abs(coefficient)) > atol:
raise ValueError(f'Coefficient is not unitary: {coefficient!r}')
self._coefficient = coefficient

@property
def coefficient(self) -> value.Scalar:
return self._coefficient

def _value_equality_values_(self) -> Any:
return self.coefficient

def _has_unitary_(self) -> bool:
return True

def __pow__(self, power):
def __pow__(self, power) -> 'cirq.GlobalPhaseGate':
if isinstance(power, (int, float)):
return GlobalPhaseOperation(self.coefficient ** power)
return GlobalPhaseGate(self.coefficient ** power)
return NotImplemented

def _unitary_(self) -> np.ndarray:
Expand All @@ -60,7 +87,7 @@ def _apply_unitary_(self, args) -> np.ndarray:
def _has_stabilizer_effect_(self) -> bool:
return True

def _act_on_(self, args: 'cirq.ActOnArgs'):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
Expand All @@ -78,7 +105,17 @@ def __str__(self) -> str:
return str(self.coefficient)

def __repr__(self) -> str:
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
return f'cirq.GlobalPhaseOperation({self.coefficient!r})'
return f'cirq.GlobalPhaseGate({self.coefficient!r})'

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
return f'cirq.global_phase_operation({self.coefficient!r})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['coefficient'])

def _qid_shape_(self) -> Tuple[int, ...]:
return tuple()


def global_phase_operation(coefficient: value.Scalar, atol: float = 1e-8) -> 'cirq.GateOperation':
return GlobalPhaseGate(coefficient, atol)()
Loading