Skip to content

Commit

Permalink
Highlight EOM buffers in Sequence drawings (#475)
Browse files Browse the repository at this point in the history
* Display eom intervals

* Showing EOM start and end buffers

* Add test drawing sequence with eom

* formating, changing color, gradient of shade

* Changing np_version for VisibleDeprecationWarning

* Handling no disable_eom, no eom start buffer

* Create EOMSegment

* Changing definition end buffer
  • Loading branch information
a-corni authored Mar 2, 2023
1 parent 41e3299 commit ed64bb1
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 3 deletions.
129 changes: 127 additions & 2 deletions pulser-core/pulser/sequence/_seq_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.interpolate import CubicSpline

Expand All @@ -44,12 +45,69 @@
]


class EOMSegment:
"""The class to mark an EOM slot."""

def __init__(self, ti: int | None = None, tf: int | None = None) -> None:
"""Class is defined from its start and end value."""
self.ti = ti
self.tf = tf
self.color = "steelblue"
self.alpha = 0.3

@property
def isempty(self) -> bool:
"""Defines if the class is empty."""
return self.ti is None or self.tf is None

@property
def nvspan(self) -> int:
"""Defines the number of points in the slot."""
return cast(int, self.tf) - cast(int, self.ti)

def draw(self, ax: Axes) -> None:
"""Draws a rectangle between the start and end value."""
if not self.isempty:
ax.axvspan(
self.ti,
self.tf,
color=self.color,
alpha=self.alpha,
zorder=-100,
)

def smooth_draw(self, ax: Axes, decreasing: bool = False) -> None:
"""Draws a rectangle with an increasing/decreasing opacity."""
if not self.isempty:
for i in range(self.nvspan):
ax.axvspan(
cast(int, self.ti) + i,
cast(int, self.ti) + i + 1,
facecolor=self.color,
alpha=self.alpha
* (
decreasing + (-1) ** decreasing * (i + 1) / self.nvspan
),
zorder=-100,
)
ax.axvline(
self.tf if decreasing else self.ti,
ax.get_ylim()[0],
ax.get_ylim()[1],
color=self.color,
alpha=self.alpha / 2.0,
)


@dataclass
class ChannelDrawContent:
"""The contents for drawing a single channel."""

samples: ChannelSamples
target: dict[Union[str, tuple[int, int]], Any]
eom_intervals: list[EOMSegment]
eom_start_buffers: list[EOMSegment]
eom_end_buffers: list[EOMSegment]
interp_pts: dict[str, list[list[float]]] = field(default_factory=dict)

def __post_init__(self) -> None:
Expand Down Expand Up @@ -122,10 +180,51 @@ def gather_data(seq: pulser.sequence.Sequence, gather_output: bool) -> dict:
# List of interpolation points
interp_pts: defaultdict[str, list[list[float]]] = defaultdict(list)
target: dict[Union[str, tuple[int, int]], Any] = {}
# Extracting the EOM Buffers
eom_intervals = [
EOMSegment(eom_interval[0], eom_interval[1])
for eom_interval in sch.get_eom_mode_intervals()
]
nb_eom_intervals = len(eom_intervals)
eom_start_buffers = [EOMSegment() for _ in range(nb_eom_intervals)]
eom_end_buffers = [EOMSegment() for _ in range(nb_eom_intervals)]
in_eom_mode = False
eom_block_n = -1
# Last eom interval is extended if eom mode not disabled at the end
if nb_eom_intervals > 0 and seq.get_duration() == eom_intervals[-1].tf:
eom_intervals[-1].tf = total_duration
# sampling the channel schedule
samples = sch.get_samples()
extended_samples = samples.extend_duration(total_duration)
for slot in sch:
if slot.ti == -1:
target["initial"] = slot.targets
continue
else:
# If slot is not the first element in schedule
if sch.in_eom_mode(slot):
# EOM mode starts
if not in_eom_mode:
in_eom_mode = True
eom_block_n += 1
elif in_eom_mode:
# Buffer when EOM mode is disabled and next slot has 0 amp
in_eom_mode = False
if extended_samples.amp[slot.ti] == 0:
eom_end_buffers[eom_block_n] = EOMSegment(
slot.ti, slot.tf
)
if (
eom_block_n + 1 < nb_eom_intervals
and slot.tf == eom_intervals[eom_block_n + 1].ti
and extended_samples.det[slot.tf - 1]
== sch.eom_blocks[eom_block_n + 1].detuning_off
):
# Buffer if next is eom and final det matches det_off
eom_start_buffers[eom_block_n + 1] = EOMSegment(
slot.ti, slot.tf
)

if slot.type == "target":
target[(slot.ti, slot.tf - 1)] = slot.targets
continue
Expand All @@ -140,9 +239,12 @@ def gather_data(seq: pulser.sequence.Sequence, gather_output: bool) -> dict:
interp_pts[wf_type] += pts.tolist()

# Store everything
samples = sch.get_samples()
data[ch] = ChannelDrawContent(
samples.extend_duration(total_duration), target
extended_samples,
target,
eom_intervals,
eom_start_buffers,
eom_end_buffers,
)
if interp_pts:
data[ch].interp_pts = dict(interp_pts)
Expand Down Expand Up @@ -217,6 +319,7 @@ def phase_str(phi: float) -> str:
ph_box = dict(boxstyle="round", facecolor="ghostwhite")
area_ph_box = dict(boxstyle="round", facecolor="ghostwhite", alpha=0.7)
slm_box = dict(boxstyle="round", alpha=0.4, facecolor="grey", hatch="//")
eom_box = dict(boxstyle="round", facecolor="lightsteelblue")

# Draw masked register
if draw_register:
Expand Down Expand Up @@ -311,6 +414,9 @@ def phase_str(phi: float) -> str:
for ch, axes in ch_axes.items():
ch_obj = seq.declared_channels[ch]
ch_data = data[ch]
ch_eom_intervals = data[ch].eom_intervals
ch_eom_start_buffers = data[ch].eom_start_buffers
ch_eom_end_buffers = data[ch].eom_end_buffers
basis = ch_obj.basis
ys = ch_data.get_input_curves()
ys_mod = [()] * 3
Expand Down Expand Up @@ -527,6 +633,25 @@ def phase_str(phi: float) -> str:
bbox=ph_box,
)

# Draw the EOM intervals
for ch_eom_start_buffer, ch_eom_interval, ch_eom_end_buffer in zip(
ch_eom_start_buffers, ch_eom_intervals, ch_eom_end_buffers
):
for ax in axes:
ch_eom_start_buffer.smooth_draw(ax, decreasing=False)
ch_eom_interval.draw(ax)
ch_eom_end_buffer.smooth_draw(ax, decreasing=True)
tgt_txt_x = ch_eom_start_buffer.ti or ch_eom_interval.ti
tgt_txt_y = axes[0].get_ylim()[1]
axes[0].text(
tgt_txt_x,
tgt_txt_y,
"EOM",
fontsize=12,
ha="left",
va="top",
bbox=eom_box,
)
# Draw the SLM mask
if seq._slm_mask_targets and seq._slm_mask_time:
tf_m = seq._slm_mask_time[1]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_register_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_creation(layout, layout3d):
np_version = tuple(map(int, np.__version__.split(".")))
context_manager = (
pytest.warns(np.VisibleDeprecationWarning)
if np_version < (1, 22)
if np_version < (1, 24)
else contextlib.nullcontext()
)
with context_manager, pytest.raises(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,10 @@ def test_eom_mode(mod_device):
duration, 0.0, new_eom_block.detuning_off, last_pulse_slot.type.phase
)

# Test drawing in eom mode
with patch("matplotlib.pyplot.show"):
seq.draw()


@pytest.mark.parametrize(
"initial_instruction, non_zero_detuning_off",
Expand Down

0 comments on commit ed64bb1

Please sign in to comment.