Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix: update time interval exceeded boolean #147

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/useq/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate:
phases = tplan.phases if hasattr(tplan, "phases") else [tplan]
tot_duration = 0.0
for phase in phases:
phase_duration, exceeded = _time_phase_duration(phase, s_per_timepoint)
phase_duration, exceeded = _time_phase_duration(phase, s_per_timepoint, seq)
tot_duration += phase_duration
t_interval_exceeded = t_interval_exceeded or exceeded
else:
Expand All @@ -147,7 +147,7 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate:


def _time_phase_duration(
phase: SinglePhaseTimePlan, s_per_timepoint: float
phase: SinglePhaseTimePlan, s_per_timepoint: float, seq: useq.MDASequence
) -> tuple[float, bool]:
"""Calculate duration for a single time plan phase."""
time_interval_s = phase.interval.total_seconds()
Expand All @@ -158,6 +158,28 @@ def _time_phase_duration(
# to actually acquire the data
time_interval_s = s_per_timepoint

axis = list(seq.axis_order)
# if there are no position and grid axes, then the time interval is not
# exceeded
if Axis.POSITION not in axis and Axis.GRID not in axis:
time_interval_exceeded = False
# if there are both position and grid axes, then the time interval, is not
# exceeded if the position and grid axes are before the time axis
elif Axis.POSITION in axis and Axis.GRID in axis:
if axis.index(Axis.POSITION) < axis.index(Axis.TIME) and axis.index(
Axis.GRID
) < axis.index(Axis.TIME):
time_interval_exceeded = False
# if there is only one of position or grid axes, then the time interval is
# not exceeded if that axis is before the time axis
elif Axis.POSITION in axis:
if axis.index(Axis.POSITION) < axis.index(Axis.TIME):
time_interval_exceeded = False
elif axis.index(Axis.GRID) < axis.index(Axis.TIME):
time_interval_exceeded = False

# TODO: add cases with a single pos or a single fov grid

tot_duration = (phase.num_timepoints() - 1) * time_interval_s + s_per_timepoint
return tot_duration, time_interval_exceeded

Expand Down
32 changes: 32 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,3 +429,35 @@ def test_z_plan_num_position():

def test_channel_str():
assert MDAEvent(channel="DAPI") == MDAEvent(channel={"config": "DAPI"})


def test_time_interval_exceeded():
main_seq = MDASequence(
axis_order="tc",
channels=[{"config": "DAPI", "exposure": 100}],
time_plan=TIntervalLoops(loops=10, interval=0),
)
assert not main_seq.estimate_duration().time_interval_exceeded

p_seq = main_seq.replace(axis_order="tpc", stage_position=[(1, 2, 3)])
assert p_seq.estimate_duration().time_interval_exceeded

p_seq = p_seq.replace(axis_order="ptc")
assert not p_seq.estimate_duration().time_interval_exceeded

g_seq = main_seq.replace(
axis_order="tgc", grid_plan=GridRelative(rows=1, columns=2)
)
assert g_seq.estimate_duration().time_interval_exceeded

g_seq = g_seq.replace(axis_order="gtc")
assert not g_seq.estimate_duration().time_interval_exceeded

pg_seq = g_seq.replace(axis_order="tpgc", stage_position=[(1, 2, 3)])
assert pg_seq.estimate_duration().time_interval_exceeded

pg_seq = pg_seq.replace(axis_order="ptcg")
assert pg_seq.estimate_duration().time_interval_exceeded

pg_seq = pg_seq.replace(axis_order="pgtc")
assert not pg_seq.estimate_duration().time_interval_exceeded