diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1dfc4f411..be813601e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,4 +72,4 @@ jobs: - name: Test validation with legacy jsonschema run: | pip install jsonschema==4.17.3 - pytest tests/test_abstract_repr.py + pytest tests/test_abstract_repr.py -W ignore::DeprecationWarning diff --git a/pulser-core/pulser/sequence/_seq_drawer.py b/pulser-core/pulser/sequence/_seq_drawer.py index f45f8b0f4..e26c9d2c7 100644 --- a/pulser-core/pulser/sequence/_seq_drawer.py +++ b/pulser-core/pulser/sequence/_seq_drawer.py @@ -118,7 +118,12 @@ class ChannelDrawContent: phase_modulated: bool = False def __post_init__(self) -> None: - self.curves_on = {"amplitude": True, "detuning": False, "phase": False} + is_dmm = isinstance(self.samples, DMMSamples) + self.curves_on = { + "amplitude": not is_dmm, + "detuning": is_dmm, + "phase": False, + } @property def _samples_from_curves(self) -> dict[str, str]: @@ -533,13 +538,20 @@ def phase_str(phi: float) -> str: time_scale = 1e3 if total_duration > 1e4 else 1 for ch in sampled_seq.channels: data[ch].phase_modulated = phase_modulated - if np.count_nonzero(data[ch].samples.det) > 0: - data[ch].curves_on["detuning"] = not phase_modulated - data[ch].curves_on["phase"] = phase_modulated - if (phase_modulated or draw_phase_curve) and np.count_nonzero( - data[ch].samples.phase - ) > 0: - data[ch].curves_on["phase"] = True + curves_on = data[ch].curves_on.copy() + _, det_samples_, phase_samples_ = data[ch].get_input_curves() + non_zero_det = np.count_nonzero(det_samples_) > 0 + non_zero_phase = np.count_nonzero(phase_samples_) > 0 + curves_on["detuning"] = non_zero_det ^ ( + phase_modulated and non_zero_phase + ) + curves_on["phase"] = ( + phase_modulated or draw_phase_curve + ) and non_zero_phase + + if any(curve_on for curve_on in curves_on.values()): + # The channel is not empty + data[ch].curves_on = curves_on # Boxes for qubit and phase text q_box = dict(boxstyle="round", facecolor="orange") @@ -730,6 +742,7 @@ def phase_str(phi: float) -> str: ) target_regions = [] # [[start1, [targets1], end1],...] + tgt_txt_ymax = ax_lims[0][1] * 0.92 for coords in ch_data.target: targets = list(ch_data.target[coords]) tgt_strs = [str(q) for q in targets] @@ -737,7 +750,7 @@ def phase_str(phi: float) -> str: tgt_strs = ["⚄"] elif ch_obj.addressing == "Global": tgt_strs = ["GLOBAL"] - tgt_txt_y = max_amp * 1.1 - 0.25 * (len(tgt_strs) - 1) + tgt_txt_y = tgt_txt_ymax - 0.25 * (len(tgt_strs) - 1) tgt_str = "\n".join(tgt_strs) if coords == "initial": x = t_min + final_t * 0.005 @@ -745,7 +758,7 @@ def phase_str(phi: float) -> str: if ch_obj.addressing == "Global": axes[0].text( x, - amp_top * 0.98, + tgt_txt_ymax * 1.065, tgt_strs[0], fontsize=13 if tgt_strs == ["GLOBAL"] else 17, rotation=90 if tgt_strs == ["GLOBAL"] else 0, @@ -767,7 +780,7 @@ def phase_str(phi: float) -> str: msg = r"$\phi=$" + phase_str(phase) axes[0].text( 0, - max_amp * 1.1, + tgt_txt_ymax, msg, ha="left", fontsize=12, @@ -798,7 +811,7 @@ def phase_str(phi: float) -> str: x = tf + final_t * 0.01 * (wrd_len + 1) axes[0].text( x, - max_amp * 1.1, + tgt_txt_ymax, msg, ha="left", fontsize=12, @@ -826,7 +839,7 @@ def phase_str(phi: float) -> str: msg = "\u27F2 " + phase_str(delta) axes[0].text( t_ - final_t * 8e-3, - max_amp * 1.1, + tgt_txt_ymax, msg, ha="right", fontsize=14, @@ -875,7 +888,7 @@ def phase_str(phi: float) -> str: msg = f"Basis: {data['measurement']}" if len(axes) == 1: mid_ax = axes[0] - mid_point = (amp_top + amp_bottom) / 2 + mid_point = sum(ax_lims[0]) / 2 fontsize = 12 else: mid_ax = axes[-1]