Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special handling of samples from a DMM channel #565

Merged
merged 19 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,6 @@ def validate_pulse(self, pulse: Pulse) -> None:

Args:
pulse: The pulse to validate.
channel_id: The channel ID used to index the chosen channel
on this device.
"""
if not isinstance(pulse, Pulse):
raise TypeError(
Expand Down
23 changes: 22 additions & 1 deletion pulser-core/pulser/channels/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from dataclasses import dataclass, field
from typing import Literal, Optional

import numpy as np

from pulser.channels.base_channel import Channel
from pulser.pulse import Pulse


@dataclass(init=True, repr=False, frozen=True)
Expand Down Expand Up @@ -51,7 +54,7 @@ class DMM(Channel):
bottom_detuning: Optional[float] = field(default=None, init=True)
addressing: Literal["Global"] = field(default="Global", init=False)
max_abs_detuning: Optional[float] = field(default=None, init=False)
max_amp: float = field(default=0, init=False) # can't be 0
max_amp: float = field(default=0, init=False)
min_retarget_interval: Optional[int] = field(default=None, init=False)
fixed_retarget_t: Optional[int] = field(default=None, init=False)
max_targets: Optional[int] = field(default=None, init=False)
Expand All @@ -72,3 +75,21 @@ def _undefined_fields(self) -> list[str]:
"max_duration",
]
return [field for field in optional if getattr(self, field) is None]

def validate_pulse(self, pulse: Pulse) -> None:
"""Checks if a pulse can be executed in this DMM.

Args:
pulse: The pulse to validate.
"""
super().validate_pulse(pulse)
round_detuning = np.round(pulse.detuning.samples, decimals=6)
if np.any(round_detuning > 0):
raise ValueError("The detuning in a DMM must not be positive.")
if self.bottom_detuning is not None and np.any(
round_detuning < self.bottom_detuning
):
raise ValueError(
"The detuning goes below the bottom detuning "
f"of the DMM ({self.bottom_detuning} rad/µs)."
)
3 changes: 2 additions & 1 deletion pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from pulser.json.utils import get_dataclass_defaults, obj_to_dict
from pulser.register.base_register import BaseRegister, QubitId
from pulser.register.mappable_reg import MappableRegister
from pulser.register.register_layout import COORD_PRECISION, RegisterLayout
from pulser.register.register_layout import RegisterLayout
from pulser.register.traps import COORD_PRECISION

DIMENSIONS = Literal[2, 3]

Expand Down
3 changes: 2 additions & 1 deletion pulser-core/pulser/json/supported.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@
"TriangularLatticeLayout",
),
"pulser.register.mappable_reg": ("MappableRegister",),
"pulser.register.weight_maps": ("DetuningMap",),
"pulser.devices": tuple(
[dev.name for dev in devices._valid_devices] + ["VirtualDevice"]
),
"pulser.channels": ("Rydberg", "Raman", "Microwave"),
"pulser.channels": ("Rydberg", "Raman", "Microwave", "DMM"),
"pulser.channels.eom": ("BaseEOM", "RydbergEOM", "RydbergBeam"),
"pulser.pulse": ("Pulse",),
"pulser.waveforms": (
Expand Down
139 changes: 14 additions & 125 deletions pulser-core/pulser/register/register_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,28 @@

from __future__ import annotations

import hashlib
from collections.abc import Mapping
from collections.abc import Sequence as abcSequence
from dataclasses import dataclass
from functools import cached_property
from hashlib import sha256
from operator import itemgetter
from typing import Any, Optional, cast
from typing import Any, Optional

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike

from pulser.json.utils import obj_to_dict
from pulser.register._reg_drawer import RegDrawer
from pulser.register.base_register import BaseRegister, QubitId
from pulser.register.mappable_reg import MappableRegister
from pulser.register.register import Register
from pulser.register.register3d import Register3D
from pulser.register.traps import Traps
from pulser.register.weight_maps import DetuningMap

COORD_PRECISION = 6


@dataclass(init=False, repr=False, eq=False, frozen=True)
class RegisterLayout(RegDrawer):
class RegisterLayout(Traps, RegDrawer):
"""A layout of traps out of which registers can be defined.

The traps are always sorted under the same convention: ascending order
Expand All @@ -51,96 +48,10 @@ class RegisterLayout(RegDrawer):
slug: An optional identifier for the layout.
"""

_trap_coordinates: ArrayLike
slug: Optional[str]

def __init__(
self, trap_coordinates: ArrayLike, slug: Optional[str] = None
):
"""Initializes a RegisterLayout."""
array_type_error_msg = ValueError(
"'trap_coordinates' must be an array or list of coordinates."
)

try:
coords_arr = np.array(trap_coordinates, dtype=float)
except ValueError as e:
raise array_type_error_msg from e

shape = coords_arr.shape
if len(shape) != 2:
raise array_type_error_msg

if shape[1] not in (2, 3):
raise ValueError(
f"Each coordinate must be of size 2 or 3, not {shape[1]}."
)

if len(np.unique(trap_coordinates, axis=0)) != shape[0]:
raise ValueError(
"All trap coordinates of a register layout must be unique."
)

object.__setattr__(self, "_trap_coordinates", trap_coordinates)
object.__setattr__(self, "slug", slug)

@property
def traps_dict(self) -> dict:
"""Mapping between trap IDs and coordinates."""
return dict(enumerate(self.coords))

@cached_property # Acts as an attribute in a frozen dataclass
def _coords(self) -> np.ndarray:
coords = np.array(self._trap_coordinates, dtype=float)
# Sorting the coordinates 1st left to right, 2nd bottom to top
rounded_coords = np.round(coords, decimals=COORD_PRECISION)
dims = rounded_coords.shape[1]
sorter = [rounded_coords[:, i] for i in range(dims - 1, -1, -1)]
sorting = np.lexsort(tuple(sorter))
return cast(np.ndarray, rounded_coords[sorting])

@cached_property # Acts as an attribute in a frozen dataclass
def _coords_to_traps(self) -> dict[tuple[float, ...], int]:
return {tuple(coord): id for id, coord in self.traps_dict.items()}

@property
def coords(self) -> np.ndarray:
"""The sorted trap coordinates."""
# Copies to prevent direct access to self._coords
return self._coords.copy()

@property
def number_of_traps(self) -> int:
"""The number of traps in the layout."""
return len(self._coords)

@property
def dimensionality(self) -> int:
"""The dimensionality of the layout (2 or 3)."""
return self._coords.shape[1]

def get_traps_from_coordinates(self, *coordinates: ArrayLike) -> list[int]:
"""Finds the trap ID for a given set of trap coordinates.

Args:
coordinates: The coordinates to return the trap IDs.

Returns:
The list of trap IDs corresponding to the coordinates.
"""
traps = []
rounded_coords = np.round(
np.array(coordinates), decimals=COORD_PRECISION
)
for coord, rounded in zip(coordinates, rounded_coords):
key = tuple(rounded)
if key not in self._coords_to_traps:
raise ValueError(
f"The coordinate '{coord!s}' is not a part of the "
"RegisterLayout."
)
traps.append(self._coords_to_traps[key])
return traps
"""A shorthand for 'sorted_coords'."""
return self.sorted_coords

def define_register(
self, *trap_ids: int, qubit_ids: Optional[abcSequence[QubitId]] = None
Expand Down Expand Up @@ -205,7 +116,7 @@ def define_detuning_map(
if not set(detuning_weights.keys()) <= set(self.traps_dict):
raise ValueError(
"The trap ids of detuning weights have to be integers"
f" between 0 and {self.number_of_traps}."
f" in [0, {self.number_of_traps-1}]."
)
return DetuningMap(
itemgetter(*detuning_weights.keys())(self.traps_dict),
Expand Down Expand Up @@ -311,47 +222,25 @@ def make_mappable_register(
qubit_ids = [f"{prefix}{i}" for i in range(n_qubits)]
return MappableRegister(self, *qubit_ids)

def _safe_hash(self) -> bytes:
# Include dimensionality because the array is flattened with tobytes()
hash = sha256(bytes(self.dimensionality))
hash.update(self.coords.tobytes())
return hash.digest()

def static_hash(self) -> str:
"""Returns the layout's idempotent hash.

Python's standard hash is not idempotent as it changes between
sessions. This hash can be used when an idempotent hash is
required.

Returns:
str: An hexstring encoding the hash.

Note:
This hash will be returned as an hexstring without
the '0x' prefix (unlike what is returned by 'hex()').
"""
return self._safe_hash().hex()
@property
def _hash_object(self) -> hashlib._Hash:
return super()._hash_object

def __eq__(self, other: Any) -> bool:
if not isinstance(other, RegisterLayout):
return False
return self._safe_hash() == other._safe_hash()

def __hash__(self) -> int:
return hash(self._safe_hash())
return super().__eq__(other) and isinstance(other, RegisterLayout)

def __repr__(self) -> str:
return f"RegisterLayout_{self._safe_hash().hex()}"

def __str__(self) -> str:
return self.slug or self.__repr__()
def __hash__(self) -> int:
return hash(self._safe_hash())

def _to_dict(self) -> dict[str, Any]:
# Allows for serialization of subclasses without a special _to_dict()
return obj_to_dict(
self,
self._trap_coordinates,
slug=self.slug,
_module=__name__,
_name="RegisterLayout",
)
Expand Down
Loading