Skip to content

Commit

Permalink
Modify BaseDevice signature for channels (#447)
Browse files Browse the repository at this point in the history
* Replacing Optional with | None

* Removing calls to Device._channels

* Add default_id() to channels

* Replacing `_channels` with `channel_objects`

* Finish UTs

* Changing dict key exclusion method
  • Loading branch information
HGSilveri authored Jan 11, 2023
1 parent 38f0c91 commit dede761
Show file tree
Hide file tree
Showing 12 changed files with 372 additions and 250 deletions.
4 changes: 4 additions & 0 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@ def __repr__(self) -> str:
config += f", Basis: '{self.basis}')"
return self.name + config

def default_id(self) -> str:
"""Generates the default ID for indexing this channel in a Device."""
return f"{self.name.lower()}_{self.addressing.lower()}"

def _to_dict(self) -> dict[str, Any]:
params = {
f.name: getattr(self, f.name) for f in fields(self) if f.init
Expand Down
4 changes: 4 additions & 0 deletions pulser-core/pulser/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,7 @@ class Microwave(Channel):
def basis(self) -> Literal["XY"]:
"""The addressed basis name."""
return "XY"

def default_id(self) -> str:
"""Generates the default ID for indexing this channel in a Device."""
return f"mw_{self.addressing.lower()}"
122 changes: 96 additions & 26 deletions pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from __future__ import annotations

import json
import warnings
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass, field, fields
from sys import version_info
from typing import Any, Optional, cast
from typing import Any, cast

import numpy as np
from scipy.spatial.distance import pdist, squareform
Expand Down Expand Up @@ -53,7 +55,12 @@ class BaseDevice(ABC):
Attributes:
name: The name of the device.
dimensions: Whether it supports 2D or 3D arrays.
rybderg_level : The value of the principal quantum number :math:`n`
channel_objects: The Channel subclass instances specifying each
channel in the device.
channel_ids: Custom IDs for each channel object. When defined,
an ID must be given for each channel. If not defined, the IDs are
generated internally based on the channels' names and addressing.
rybderg_level: The value of the principal quantum number :math:`n`
when the Rydberg level used is of the form
:math:`|nS_{1/2}, m_j = +1/2\rangle`.
max_atom_num: Maximum number of atoms supported in an array.
Expand All @@ -71,14 +78,16 @@ class BaseDevice(ABC):
name: str
dimensions: DIMENSIONS
rydberg_level: int
_channels: tuple[tuple[str, Channel], ...]
min_atom_distance: float
max_atom_num: Optional[int]
max_radial_distance: Optional[int]
interaction_coeff_xy: Optional[float] = None
max_atom_num: int | None
max_radial_distance: int | None
interaction_coeff_xy: float | None = None
supports_slm_mask: bool = False
max_layout_filling: float = 0.5
reusable_channels: bool = field(default=False, init=False)
channel_ids: tuple[str, ...] | None = None
channel_objects: tuple[Channel, ...] = field(default_factory=tuple)
_channels: tuple[tuple[str, Channel], ...] = field(default_factory=tuple)

def __post_init__(self) -> None:
def type_check(
Expand All @@ -102,9 +111,6 @@ def type_check(
f"not {self.dimensions}."
)
self._validate_rydberg_level(self.rydberg_level)
for ch_id, ch_obj in self._channels:
type_check("All channel IDs", str, value_override=ch_id)
type_check("All channels", Channel, value_override=ch_obj)

for param in (
"min_atom_distance",
Expand Down Expand Up @@ -136,14 +142,6 @@ def type_check(
if not valid:
raise ValueError(msg)

if any(
ch.basis == "XY" for _, ch in self._channels
) and not isinstance(self.interaction_coeff_xy, float):
raise TypeError(
"When the device has a 'Microwave' channel, "
"'interaction_coeff_xy' must be a 'float',"
f" not '{type(self.interaction_coeff_xy)}'."
)
type_check("supports_slm_mask", bool)
type_check("reusable_channels", bool)

Expand All @@ -154,13 +152,83 @@ def type_check(
f"not {self.max_layout_filling}."
)

if self._channels:
warnings.warn(
"Specifying the channels of a device through the '_channels'"
" argument is deprecated since v0.9.0 and will be removed in"
" v0.10.0. Instead, use 'channel_objects' to specify the"
" channels and, optionally, 'channel_ids' to specify custom"
" channel IDs.",
DeprecationWarning,
stacklevel=2,
)
if self.channel_objects or self.channel_ids:
raise ValueError(
"'_channels' can't be specified when 'channel_objects'"
" or 'channel_ids' are also provided."
)
ch_objs = []
ch_ids = []
for ch_id, ch_obj in self._channels:
ch_ids.append(ch_id)
ch_objs.append(ch_obj)
object.__setattr__(self, "channel_ids", tuple(ch_ids))
object.__setattr__(self, "channel_objects", tuple(ch_objs))
object.__setattr__(self, "_channels", ())

for ch_obj in self.channel_objects:
type_check("All channels", Channel, value_override=ch_obj)

if self.channel_ids is not None:
if not (
isinstance(self.channel_ids, (tuple, list))
and all(isinstance(el, str) for el in self.channel_ids)
):
raise TypeError(
"When defined, 'channel_ids' must be a tuple or a list "
"of strings."
)
if len(self.channel_ids) != len(set(self.channel_ids)):
raise ValueError(
"When defined, 'channel_ids' can't have "
"repeated elements."
)
if len(self.channel_ids) != len(self.channel_objects):
raise ValueError(
"When defined, the number of channel IDs must"
" match the number of channel objects."
)
else:
# Make the channel IDs from the default IDs
ids_counter: Counter = Counter()
ids = []
for ch_obj in self.channel_objects:
id = ch_obj.default_id()
ids_counter.update([id])
if ids_counter[id] > 1:
# If there is more than one with the same ID
id += f"_{ids_counter[id]}"
ids.append(id)
object.__setattr__(self, "channel_ids", tuple(ids))

if any(
ch.basis == "XY" for ch in self.channel_objects
) and not isinstance(self.interaction_coeff_xy, float):
raise TypeError(
"When the device has a 'Microwave' channel, "
"'interaction_coeff_xy' must be a 'float',"
f" not '{type(self.interaction_coeff_xy)}'."
)

def to_tuple(obj: tuple | list) -> tuple:
if isinstance(obj, (tuple, list)):
obj = tuple(to_tuple(el) for el in obj)
return obj

# Turns mutable lists into immutable tuples
object.__setattr__(self, "_channels", to_tuple(self._channels))
for param in self._params():
if "channel" in param:
object.__setattr__(self, param, to_tuple(getattr(self, param)))

@property
@abstractmethod
Expand All @@ -170,12 +238,12 @@ def _optional_parameters(self) -> tuple[str, ...]:
@property
def channels(self) -> dict[str, Channel]:
"""Dictionary of available channels on this device."""
return dict(self._channels)
return dict(zip(cast(tuple, self.channel_ids), self.channel_objects))

@property
def supported_bases(self) -> set[str]:
"""Available electronic transitions for control and measurement."""
return {ch.basis for ch in self.channels.values()}
return {ch.basis for ch in self.channel_objects}

@property
def interaction_coeff(self) -> float:
Expand Down Expand Up @@ -364,11 +432,13 @@ def _to_dict(self) -> dict[str, Any]:

@abstractmethod
def _to_abstract_repr(self) -> dict[str, Any]:
ex_params = ("_channels", "channel_objects", "channel_ids")
params = self._params()
for p in ex_params:
params.pop(p, None)
ch_list = []
for ch_name, ch_obj in params.pop("_channels"):
for ch_name, ch_obj in self.channels.items():
ch_list.append(ch_obj._to_abstract_repr(ch_name))

return {"version": "1", "channels": ch_list, **params}

def to_abstract_repr(self) -> str:
Expand Down Expand Up @@ -411,7 +481,7 @@ class Device(BaseDevice):

def __post_init__(self) -> None:
super().__post_init__()
for ch_id, ch_obj in self._channels:
for ch_id, ch_obj in self.channels.items():
if ch_obj.is_virtual():
_sep = "', '"
raise ValueError(
Expand Down Expand Up @@ -466,7 +536,7 @@ def _specs(self, for_docs: bool = False) -> str:
]

ch_lines = []
for name, ch in self._channels:
for name, ch in self.channels.items():
if for_docs:
ch_lines += [
f" - ID: '{name}'",
Expand Down Expand Up @@ -539,8 +609,8 @@ class VirtualDevice(BaseDevice):
on the same pulse sequence.
"""
min_atom_distance: float = 0
max_atom_num: Optional[int] = None
max_radial_distance: Optional[int] = None
max_atom_num: int | None = None
max_radial_distance: int | None = None
supports_slm_mask: bool = True
reusable_channels: bool = True

Expand Down
90 changes: 39 additions & 51 deletions pulser-core/pulser/devices/_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,42 +26,33 @@
max_radial_distance=50,
min_atom_distance=4,
supports_slm_mask=True,
_channels=(
(
"rydberg_global",
Rydberg.Global(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 2.5,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
channel_objects=(
Rydberg.Global(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 2.5,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
(
"rydberg_local",
Rydberg.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
Rydberg.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
(
"raman_local",
Raman.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
Raman.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
clock_period=4,
min_duration=16,
max_duration=2**26,
),
),
)
Expand All @@ -73,23 +64,20 @@
max_atom_num=100,
max_radial_distance=60,
min_atom_distance=5,
_channels=(
(
"rydberg_global",
Rydberg.Global(
max_abs_detuning=2 * np.pi * 4,
max_amp=2 * np.pi * 3,
clock_period=4,
min_duration=16,
max_duration=2**26,
mod_bandwidth=4,
eom_config=RydbergEOM(
limiting_beam=RydbergBeam.RED,
max_limiting_amp=40 * 2 * np.pi,
intermediate_detuning=700 * 2 * np.pi,
mod_bandwidth=24,
controlled_beams=(RydbergBeam.BLUE,),
),
channel_objects=(
Rydberg.Global(
max_abs_detuning=2 * np.pi * 4,
max_amp=2 * np.pi * 3,
clock_period=4,
min_duration=16,
max_duration=2**26,
mod_bandwidth=4,
eom_config=RydbergEOM(
limiting_beam=RydbergBeam.RED,
max_limiting_amp=40 * 2 * np.pi,
intermediate_detuning=700 * 2 * np.pi,
mod_bandwidth=24,
controlled_beams=(RydbergBeam.BLUE,),
),
),
),
Expand Down
12 changes: 6 additions & 6 deletions pulser-core/pulser/devices/_mock_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
min_atom_distance=0.0,
interaction_coeff_xy=3700.0,
supports_slm_mask=True,
_channels=(
("rydberg_global", Rydberg.Global(None, None, max_duration=None)),
("rydberg_local", Rydberg.Local(None, None, max_duration=None)),
("raman_global", Raman.Global(None, None, max_duration=None)),
("raman_local", Raman.Local(None, None, max_duration=None)),
("mw_global", Microwave.Global(None, None, max_duration=None)),
channel_objects=(
Rydberg.Global(None, None, max_duration=None),
Rydberg.Local(None, None, max_duration=None),
Raman.Global(None, None, max_duration=None),
Raman.Local(None, None, max_duration=None),
Microwave.Global(None, None, max_duration=None),
),
)
13 changes: 9 additions & 4 deletions pulser-core/pulser/json/abstract_repr/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,17 @@ def _deserialize_device_object(obj: dict[str, Any]) -> Device | VirtualDevice:
device_cls: Type[Device] | Type[VirtualDevice] = (
VirtualDevice if obj["is_virtual"] else Device
)
_channels = tuple(
(ch["id"], _deserialize_channel(ch)) for ch in obj["channels"]
ch_ids = []
ch_objs = []
for ch in obj["channels"]:
ch_ids.append(ch["id"])
ch_objs.append(_deserialize_channel(ch))
params: dict[str, Any] = dict(
channel_ids=tuple(ch_ids), channel_objects=tuple(ch_objs)
)
params: dict[str, Any] = {"_channels": _channels}
ex_params = ("_channels", "channel_objects", "channel_ids")
for param in dataclasses.fields(device_cls):
if not param.init or param.name == "_channels":
if not param.init or param.name in ex_params:
continue
if param.name == "pre_calibrated_layouts":
key = "pre_calibrated_layouts"
Expand Down
Loading

0 comments on commit dede761

Please sign in to comment.