Skip to content

Commit

Permalink
Apply comments of removing dynamical decoupling model class.
Browse files Browse the repository at this point in the history
  • Loading branch information
babacry committed May 8, 2024
1 parent 427227d commit 3ae49b1
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 168 deletions.
1 change: 0 additions & 1 deletion cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@
from cirq.transformers import (
AbstractInitialMapper,
add_dynamical_decoupling,
DynamicalDecouplingModel,
align_left,
align_right,
CompilationTargetGateset,
Expand Down
1 change: 0 additions & 1 deletion cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def _symmetricalqidpair(qids):
import sympy

return {
'DynamicalDecouplingModel': cirq.DynamicalDecouplingModel,
'AmplitudeDampingChannel': cirq.AmplitudeDampingChannel,
'AnyIntegerPowerGateFamily': cirq.AnyIntegerPowerGateFamily,
'AnyUnitaryGateFamily': cirq.AnyUnitaryGateFamily,
Expand Down

This file was deleted.

This file was deleted.

5 changes: 1 addition & 4 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,7 @@

from cirq.transformers.drop_negligible_operations import drop_negligible_operations

from cirq.transformers.dynamical_decoupling import (
add_dynamical_decoupling,
DynamicalDecouplingModel,
)
from cirq.transformers.dynamical_decoupling import add_dynamical_decoupling

from cirq.transformers.eject_z import eject_z

Expand Down
158 changes: 60 additions & 98 deletions cirq-core/cirq/transformers/dynamical_decoupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,143 +14,105 @@

"""Transformer pass that adds dynamical decoupling operations to a circuit."""

import enum
from functools import reduce
from typing import Any, Dict, Optional, Tuple
from typing import Dict, Optional, Sequence, Tuple, Union

from cirq.transformers import transformer_api
import cirq
from cirq import value
import numpy as np


@enum.unique
class _DynamicalDecouplingSchema(enum.Enum):
"""Supported schemes of dynamical decoupling."""

XX_PAIR = 'XX_PAIR'
X_XINV = 'X_XINV'
YY_PAIR = 'YY_PAIR'
Y_YINV = 'Y_YINV'


def _repeat_sequence(base_sequence: list['cirq.Gate'], num_idle_moments: int):
def _repeat_sequence(
base_sequence: Sequence['cirq.Gate'], num_idle_moments: int
) -> Sequence['cirq.Gate']:
"""Returns the longest possible dynamical decoupling sequence."""
repeat_times = num_idle_moments // len(base_sequence)
return base_sequence * repeat_times
return list(base_sequence) * repeat_times


def _generate_dd_sequence_from_schema(
schema: _DynamicalDecouplingSchema, num_idle_moments: int = 2
) -> list['cirq.Gate']:
def _get_dd_sequence_from_schema_name(schema: str) -> Sequence['cirq.Gate']:
"""Gets dynamical decoupling sequence from a schema name."""
dd_sequence: Sequence['cirq.Gate']
match schema:
case _DynamicalDecouplingSchema.XX_PAIR:
return _repeat_sequence([cirq.X, cirq.X], num_idle_moments)
case _DynamicalDecouplingSchema.X_XINV:
return _repeat_sequence([cirq.X, cirq.X**-1], num_idle_moments)
case _DynamicalDecouplingSchema.YY_PAIR:
return _repeat_sequence([cirq.Y, cirq.Y], num_idle_moments)
case _DynamicalDecouplingSchema.Y_YINV:
return _repeat_sequence([cirq.Y, cirq.Y**-1], num_idle_moments)
case 'XX_PAIR':
dd_sequence = (cirq.X, cirq.X)
case 'X_XINV':
dd_sequence = (cirq.X, cirq.X**-1)
case 'YY_PAIR':
dd_sequence = (cirq.Y, cirq.Y)
case 'Y_YINV':
dd_sequence = (cirq.Y, cirq.Y**-1)
case _:
raise ValueError('Invalid schema name.')
return dd_sequence


def _validate_dd_sequence(dd_sequence: Sequence['cirq.Gate']) -> Tuple[bool, Optional[str]]:
"""Validates a given dynamical decoupling sequence.
Args:
dd_sequence: Input dynamical sequence to be validated.
def _validate_dd_sequence(dd_sequence: list['cirq.Gate']) -> None:
Returns:
A tuple containing:
- is_valid (bool): True if the dd sequence is valid, False otherwise.
- error_message (str): An error message if the dd sequence is invalid, else None.
"""
if len(dd_sequence) < 2:
raise ValueError('Invalid dynamical decoupling sequence. Expect more than one gates.')
return False, 'Invalid dynamical decoupling sequence. Expect more than one gates.'
matrices = [cirq.unitary(gate) for gate in dd_sequence]
product = reduce(np.matmul, matrices)

if not cirq.equal_up_to_global_phase(product, np.eye(2)):
raise ValueError(
"Invalid dynamical decoupling sequence. Expect sequence production equals identity"
f" up to a global phase, got {product}.".replace('\n', ' ')
return False, (
'Invalid dynamical decoupling sequence. Expect sequence production equals'
f' identity up to a global phase, got {product}.'.replace('\n', ' ')
)


@value.value_equality
class DynamicalDecouplingModel:
"""Dynamical decoupling model that generates dynamical decoupling operation sequences."""

def __init__(
self,
schema: Optional[_DynamicalDecouplingSchema] = None,
base_dd_sequence: Optional[list['cirq.Gate']] = None,
):
if not schema and not base_dd_sequence:
raise ValueError(
'Specify either schema or base_dd_sequence to construct a valid'
' DynamicalDecouplingModel.'
)
self.schema = schema
self.base_dd_sequence = base_dd_sequence
if base_dd_sequence:
_validate_dd_sequence(base_dd_sequence)

def generate_dd_sequence(self, num_idle_moments: int = 2) -> list['cirq.Gate']:
"""Returns the longest possible dynamical decoupling sequence."""
if num_idle_moments <= 0:
return []
if self.schema:
dd_sequence = _generate_dd_sequence_from_schema(self.schema, num_idle_moments)
elif self.base_dd_sequence:
dd_sequence = _repeat_sequence(self.base_dd_sequence, num_idle_moments)
return dd_sequence

@classmethod
def from_schema(cls, schema: str):
"""Create dynamical decoupling model according to a given schema."""
if not schema in _DynamicalDecouplingSchema.__members__:
raise ValueError("Invalid schema name.")
return cls(schema=_DynamicalDecouplingSchema[schema])

@classmethod
def from_base_dd_sequence(cls, base_dd_sequence: list['cirq.Gate']):
"""Create dynamical decoupling model according to a base sequence."""
return cls(base_dd_sequence=base_dd_sequence)

def _json_dict_(self) -> Dict[str, Any]:
d: Dict[str, Any] = {}
if self.schema:
d['schema'] = self.schema.name
if self.base_dd_sequence:
d['base_dd_sequence'] = self.base_dd_sequence
return d

@classmethod
def _from_json_dict_(cls, schema=None, base_dd_sequence=None, **kwargs):
if schema:
return cls(schema=_DynamicalDecouplingSchema[schema])
if base_dd_sequence:
return cls(base_dd_sequence=base_dd_sequence)

def _value_equality_values_(self) -> Any:
return self.schema, self.base_dd_sequence
return True, None


@transformer_api.transformer
def add_dynamical_decoupling(
circuit: 'cirq.AbstractCircuit',
*,
context: Optional['cirq.TransformerContext'] = None,
dd_model: DynamicalDecouplingModel = DynamicalDecouplingModel.from_schema("X_XINV"),
schema: Union[str, Sequence['cirq.Gate']] = 'X_XINV',
) -> 'cirq.Circuit':
"""Add dynamical decoupling gate operations to a given circuit.
"""Adds dynamical decoupling gate operations to idle moments of a given circuit.
This transformer preserves the moment structure of the circuit.
Args:
circuit: Input circuit to transform.
context: `cirq.TransformerContext` storing common configurable options for transformers.
dd_model: Dynamical decoupling model that defines the schema to generate dynamical
decoupling sequences.
schema: Dynamical decoupling schema name or a dynamical decoupling sequence.
If a schema is specified, provided dynamical decouping sequence will be used.
Otherwise, customized dynamical decoupling sequence will be applied.
Return:
Returns:
A copy of the input circuit with dynamical decoupling operations.
Raises:
ValueError: If schema is not valid.
"""
last_busy_moment_by_qubits: Dict['cirq.Qid', int] = {q: 0 for q in circuit.all_qubits()}
insert_into: list[Tuple[int, 'cirq.OP_TREE']] = []

if isinstance(schema, str):
try:
base_dd_sequence = _get_dd_sequence_from_schema_name(schema)
except ValueError:
raise
else:
is_valid, error_message = _validate_dd_sequence(schema)
if is_valid:
base_dd_sequence = schema
else:
raise ValueError(error_message)

for moment_id, moment in enumerate(circuit):
for q in moment.qubits:
insert_gates = dd_model.generate_dd_sequence(
num_idle_moments=moment_id - last_busy_moment_by_qubits[q] - 1
insert_gates = _repeat_sequence(
base_dd_sequence, num_idle_moments=moment_id - last_busy_moment_by_qubits[q] - 1
)
for idx, gate in enumerate(insert_gates):
insert_into.append((last_busy_moment_by_qubits[q] + idx + 1, gate.on(q)))
Expand Down
Loading

0 comments on commit 3ae49b1

Please sign in to comment.