Skip to content

Commit

Permalink
Update phase jump behaviour (#416)
Browse files Browse the repository at this point in the history
* Making `phase_jump_time` a Channel property

* Accounting for fall time in the phase jump buffer

* Adding note explaining the phase jump behaviour

* Removing `phase_jump_time` from switch_device()

* Addressing review comment

* Adding simple test for Pulse.fall_time()

* Incorporate the EOM features

* Skip phase jump with the 'no-delay' protocol

* Refactoring _Schedule.add_pulse()
  • Loading branch information
HGSilveri authored Dec 16, 2022
1 parent 07e17e7 commit f701f55
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 158 deletions.
27 changes: 10 additions & 17 deletions pulser-core/pulser/channels/base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ class Channel(ABC):
max_abs_detuning: Maximum possible detuning (in rad/µs), in absolute
value.
max_amp: Maximum pulse amplitude (in rad/µs).
phase_jump_time: Time taken to change the phase between consecutive
pulses (in ns).
min_retarget_interval: Minimum time required between the ends of two
target instructions (in ns).
fixed_retarget_t: Time taken to change the target (in ns).
Expand All @@ -82,7 +80,6 @@ class Channel(ABC):
addressing: Literal["Global", "Local"]
max_abs_detuning: Optional[float]
max_amp: Optional[float]
phase_jump_time: int = 0
min_retarget_interval: Optional[int] = None
fixed_retarget_t: Optional[int] = None
max_targets: Optional[int] = None
Expand Down Expand Up @@ -125,15 +122,13 @@ def __post_init__(self) -> None:
parameters = [
"max_amp",
"max_abs_detuning",
"phase_jump_time",
"clock_period",
"min_duration",
"max_duration",
"mod_bandwidth",
]
non_negative = [
"max_abs_detuning",
"phase_jump_time",
"min_retarget_interval",
"fixed_retarget_t",
]
Expand Down Expand Up @@ -203,6 +198,14 @@ def rise_time(self) -> int:
else:
return 0

@property
def phase_jump_time(self) -> int:
"""Time taken to change the phase between consecutive pulses (in ns).
Corresponds to two times the rise time.
"""
return self.rise_time * 2

def is_virtual(self) -> bool:
"""Whether the channel is virtual (i.e. partially defined)."""
return bool(self._undefined_fields())
Expand All @@ -226,7 +229,6 @@ def Local(
cls,
max_abs_detuning: Optional[float],
max_amp: Optional[float],
phase_jump_time: int = 0,
min_retarget_interval: int = 0,
fixed_retarget_t: int = 0,
max_targets: Optional[int] = None,
Expand All @@ -238,8 +240,6 @@ def Local(
max_abs_detuning: Maximum possible detuning (in rad/µs), in
absolute value.
max_amp: Maximum pulse amplitude (in rad/µs).
phase_jump_time: Time taken to change the phase between
consecutive pulses (in ns).
min_retarget_interval: Minimum time required between two
target instructions (in ns).
fixed_retarget_t: Time taken to change the target (in ns).
Expand All @@ -261,7 +261,6 @@ def Local(
"Local",
max_abs_detuning,
max_amp,
phase_jump_time,
min_retarget_interval,
fixed_retarget_t,
max_targets,
Expand All @@ -273,7 +272,6 @@ def Global(
cls,
max_abs_detuning: Optional[float],
max_amp: Optional[float],
phase_jump_time: int = 0,
**kwargs: Any,
) -> Channel:
"""Initializes the channel with global addressing.
Expand All @@ -282,8 +280,6 @@ def Global(
max_abs_detuning: Maximum possible detuning (in rad/µs), in
absolute value.
max_amp: Maximum pulse amplitude (in rad/µs).
phase_jump_time: Time taken to change the phase between
consecutive pulses (in ns).
Keyword Args:
clock_period(int, default=4): The duration of a clock cycle
Expand All @@ -296,9 +292,7 @@ def Global(
mod_bandwidth(Optional[float], default=None): The modulation
bandwidth at -3dB (50% reduction), in MHz.
"""
return cls(
"Global", max_abs_detuning, max_amp, phase_jump_time, **kwargs
)
return cls("Global", max_abs_detuning, max_amp, **kwargs)

def validate_duration(self, duration: int) -> int:
"""Validates and adapts the duration of an instruction on this channel.
Expand Down Expand Up @@ -474,8 +468,7 @@ def __repr__(self) -> str:
f"{self.max_abs_detuning}"
f"{' rad/µs' if self.max_abs_detuning else ''}, "
f"Max Amplitude: {self.max_amp}"
f"{' rad/µs' if self.max_amp else ''}, "
f"Phase Jump Time: {self.phase_jump_time} ns"
f"{' rad/µs' if self.max_amp else ''}"
)
if self.addressing == "Local":
config += (
Expand Down
1 change: 0 additions & 1 deletion pulser-core/pulser/devices/_device_datacls.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,6 @@ def _specs(self, for_docs: bool = False) -> str:
+ r"- Maximum :math:`|\delta|`:"
+ f" {ch.max_abs_detuning:.4g} rad/µs"
),
f"\t- Phase Jump Time: {ch.phase_jump_time} ns",
]
if ch.addressing == "Local":
ch_lines += [
Expand Down
4 changes: 0 additions & 4 deletions pulser-core/pulser/devices/_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
Rydberg.Global(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 2.5,
phase_jump_time=0,
clock_period=4,
min_duration=16,
max_duration=2**26,
Expand All @@ -43,7 +42,6 @@
Rydberg.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
phase_jump_time=0,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
Expand All @@ -57,7 +55,6 @@
Raman.Local(
max_abs_detuning=2 * np.pi * 20,
max_amp=2 * np.pi * 10,
phase_jump_time=0,
min_retarget_interval=220,
fixed_retarget_t=0,
max_targets=1,
Expand All @@ -82,7 +79,6 @@
Rydberg.Global(
max_abs_detuning=2 * np.pi * 4,
max_amp=2 * np.pi * 3,
phase_jump_time=500,
clock_period=4,
min_duration=16,
max_duration=2**26,
Expand Down
89 changes: 53 additions & 36 deletions pulser-core/pulser/sequence/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,44 +260,32 @@ def add_pulse(
last = self[channel][-1]
t0 = last.tf
current_max_t = max(t0, *phase_barrier_ts)
for ch, ch_schedule in self.items():
if protocol == "no-delay" or ch == channel:
continue
this_chobj = self[ch].channel_obj
in_eom_mode = self[ch].in_eom_mode()
for op in ch_schedule[::-1]:
if not isinstance(op.type, Pulse):
if op.tf + 2 * this_chobj.rise_time <= current_max_t:
# No pulse behind 'op' needing a delay
break
elif (
op.tf
+ op.type.fall_time(this_chobj, in_eom_mode=in_eom_mode)
<= current_max_t
):
break
elif op.targets & last.targets or protocol == "wait-for-all":
current_max_t = op.tf + op.type.fall_time(
this_chobj, in_eom_mode=in_eom_mode
)
break

# Buffer to add between pulses of different phase
phase_jump_buffer = 0
try:
# Gets the last pulse on the channel
last_pulse_slot = self[channel].last_pulse_slot()
last_pulse = cast(Pulse, last_pulse_slot.type)
# Checks if the current pulse changes the phase
if last_pulse.phase != pulse.phase:
# Subtracts the time that has already elapsed since the
# last pulse from the phase_jump_time
phase_jump_buffer = self[
channel
].channel_obj.phase_jump_time - (t0 - last_pulse_slot.tf)
except RuntimeError:
# No previous pulse
pass
if protocol != "no-delay":
current_max_t = self._find_add_delay(
current_max_t, channel, protocol
)
try:
# Gets the last pulse on the channel
last_pulse_slot = self[channel].last_pulse_slot()
last_pulse = cast(Pulse, last_pulse_slot.type)
# Checks if the current pulse changes the phase
if last_pulse.phase != pulse.phase:
# Subtracts the time that has already elapsed since the
# last pulse from the phase_jump_time and adds the
# fall_time to let the last pulse ramp down
ch_obj = self[channel].channel_obj
phase_jump_buffer = (
ch_obj.phase_jump_time
+ last_pulse.fall_time(
ch_obj, in_eom_mode=self[channel].in_eom_mode()
)
- (t0 - last_pulse_slot.tf)
)
except RuntimeError:
# No previous pulse
pass

delay_duration = max(current_max_t - t0, phase_jump_buffer)
if delay_duration > 0:
Expand Down Expand Up @@ -368,3 +356,32 @@ def wait_for_fall(self, channel: str) -> None:
# If there is a fall time, a delay is added to account for it
if fall_time > 0:
self.add_delay(self[channel].adjust_duration(fall_time), channel)

def _find_add_delay(self, t0: int, channel: str, protocol: str) -> int:
current_max_t = t0
for ch, ch_schedule in self.items():
if ch == channel:
continue
this_chobj = self[ch].channel_obj
in_eom_mode = self[ch].in_eom_mode()
for op in ch_schedule[::-1]:
if not isinstance(op.type, Pulse):
if op.tf + 2 * this_chobj.rise_time <= current_max_t:
# No pulse behind 'op' needing a delay
break
elif (
op.tf
+ op.type.fall_time(this_chobj, in_eom_mode=in_eom_mode)
<= current_max_t
):
break
elif (
op.targets & self[channel][-1].targets
or protocol == "wait-for-all"
):
current_max_t = op.tf + op.type.fall_time(
this_chobj, in_eom_mode=in_eom_mode
)
break

return current_max_t
43 changes: 12 additions & 31 deletions pulser-core/pulser/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from pulser.pulse import Pulse
from pulser.register.base_register import BaseRegister, QubitId
from pulser.register.mappable_reg import MappableRegister
from pulser.sampler import sample
from pulser.sequence._basis_ref import _QubitRef
from pulser.sequence._call import _Call
from pulser.sequence._schedule import _ChannelSchedule, _Schedule, _TimeSlot
Expand Down Expand Up @@ -436,9 +435,7 @@ def switch_device(
)

# Channel match
sample_seq = sample(self)
channel_match: dict[str, Any] = {}
alert_phase_jump = False
strict_error_message = ""
ch_type_er_mess = ""
for o_d_ch_name, o_d_ch_obj in self.declared_channels.items():
Expand Down Expand Up @@ -492,29 +489,12 @@ def switch_device(
)
continue

ch_samples = sample_seq.channel_samples[o_d_ch_name]
ch_sample_phase = ch_samples.phase
# Find if there is phase change between pulses or not
phase_is_constant = True
if ch_sample_phase.size != 0:
phase_is_constant = bool(
np.all(ch_sample_phase == ch_sample_phase[0])
)
# Phase_jump_time and clock_period check
phase_jump_time_check = (
o_d_ch_obj.phase_jump_time == n_d_ch_obj.phase_jump_time
)
clock_period_check = (
o_d_ch_obj.clock_period == n_d_ch_obj.clock_period
)
if clock_period_check and (
phase_is_constant or phase_jump_time_check
):
# Clock_period check
if o_d_ch_obj.clock_period == n_d_ch_obj.clock_period:
channel_match[o_d_ch_name] = n_d_ch_id
alert_phase_jump = not phase_jump_time_check
break
strict_error_message = strict_error_message or (
base_msg + " with the same phase_jump_time & clock_period."
base_msg + " with the same clock_period."
)

if None in channel_match.values():
Expand Down Expand Up @@ -543,13 +523,6 @@ def switch_device(
else:
sw_channel_args[1] = channel_match[sw_channel_args[0]]

if strict and alert_phase_jump:
warnings.warn(
"The phase_jump_time of the matching channel on "
+ "the the new device is different, take it into account"
+ " for the upcoming pulses.",
stacklevel=2,
)
new_seq.declare_channel(*sw_channel_args, **sw_channel_kw_args)
return new_seq

Expand Down Expand Up @@ -848,7 +821,8 @@ def add_eom_pulse(
Note:
When the phase between pulses is changed, the necessary buffer
time for a phase jump will still be enforced.
time for a phase jump will still be enforced (unless
``protocol='no-delay'``).
Args:
channel: The name of the channel to add the pulse to.
Expand Down Expand Up @@ -911,6 +885,13 @@ def add(
- ``'wait-for-all'``: Before adding the pulse, adds a delay
that idles the channel until the end of the other channels'
latest pulse.
Note:
When the phase of the pulse to add is different than the phase of
the previous pulse on the channel, a delay between the two pulses
might be automatically added to ensure the channel's
`phase_jump_time` is respected. To override this behaviour, use
the ``'no-delay'`` protocol.
"""
self._validate_channel(channel, block_eom_mode=True)
self._add(pulse, channel, protocol)
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def mod_device() -> Device:
2 * np.pi * 20,
2 * np.pi * 10,
max_targets=2,
phase_jump_time=0,
fixed_retarget_t=0,
clock_period=4,
min_retarget_interval=220,
Expand All @@ -77,7 +76,6 @@ def mod_device() -> Device:
2 * np.pi * 20,
2 * np.pi * 10,
max_targets=2,
phase_jump_time=0,
fixed_retarget_t=0,
min_retarget_interval=220,
clock_period=4,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,17 @@ def test_repr():
)
r1 = (
"Raman.Local(Max Absolute Detuning: None, Max Amplitude: "
"2 rad/µs, Phase Jump Time: 0 ns, Minimum retarget time: 1000 ns, "
"2 rad/µs, Minimum retarget time: 1000 ns, "
"Fixed retarget time: 200 ns, Max targets: 4, Clock period: 4 ns, "
"Minimum pulse duration: 16 ns, Basis: 'digital')"
)
assert raman.__str__() == r1

ryd = Rydberg.Global(50, None, phase_jump_time=300, mod_bandwidth=4)
ryd = Rydberg.Global(50, None, mod_bandwidth=4)
r2 = (
"Rydberg.Global(Max Absolute Detuning: 50 rad/µs, "
"Max Amplitude: None, Phase Jump Time: 300 ns, "
"Clock period: 1 ns, Minimum pulse duration: 1 ns, "
"Max Amplitude: None, Clock period: 1 ns, "
"Minimum pulse duration: 1 ns, "
"Maximum pulse duration: 100000000 ns, "
"Modulation Bandwidth: 4 MHz, Basis: 'ground-rydberg')"
)
Expand Down
Loading

0 comments on commit f701f55

Please sign in to comment.