Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cooldown_period to OrTrigger #1004

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ APScheduler, see the :doc:`migration section <migration>`.
acquire the same schedules at once
- Changed ``SQLAlchemyDataStore`` to automatically create the explicitly specified
schema if it's missing (PR by @zhu0629)
- Added cooldown_period to OrTrigger
(#453 <https://github.com/agronholm/apscheduler/issues/453>_; PR by @HomerusJa)

**4.0.0a5**

Expand Down
92 changes: 78 additions & 14 deletions src/apscheduler/triggers/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,30 +121,94 @@ class OrTrigger(BaseCombiningTrigger):
fire times.

:param triggers: triggers to combine
:param cooldown_period: minimum time between two consecutive fires (in seconds, or as
timedelta)
"""

cooldown_period: timedelta = attrs.field(converter=as_timedelta, default=0)
_last_fire_time: datetime | None = attrs.field(default=None, eq=False, init=False)
max_iterations: int | None = 10000

def _get_next_valid_fire_time(self) -> tuple[datetime | None, list[int]]:
"""
Find the next valid fire time that respects the cooldown period.

:raises MaxIterationsReached: If the maximum number of iterations is reached
:returns: A tuple of (fire_time, trigger_indices) where fire_time is the next valid
fire time (or None if no valid time exists) and trigger_indices is a list
of indices of triggers that produced this fire time.
"""
for _ in range(self.max_iterations):
earliest_time = min(
(
fire_time
for fire_time in self._next_fire_times
if fire_time is not None
),
default=None,
)
if earliest_time is None:
return None, []

# Find all triggers that produced this fire time
trigger_indices = [
i
for i, fire_time in enumerate(self._next_fire_times)
if fire_time == earliest_time
]

# Check if we need to respect cooldown period
if (
self.cooldown_period > timedelta(0)
and self._last_fire_time is not None
and earliest_time - self._last_fire_time < self.cooldown_period
):
# Get next fire times for all triggers that would have fired
for i in trigger_indices:
self._next_fire_times[i] = self.triggers[i].next()
continue

return earliest_time, trigger_indices
else:
raise MaxIterationsReached

def next(self) -> datetime | None:
# Fill out the fire times on the first run
# Initialize fire times if needed
if not self._next_fire_times:
self._next_fire_times = [t.next() for t in self.triggers]
self._last_fire_time = None

# Find out the earliest of the fire times
earliest_time: datetime | None = min(
(fire_time for fire_time in self._next_fire_times if fire_time is not None),
default=None,
)
if earliest_time is not None:
# Generate new fire times for the trigger(s) that generated the earliest
# fire time
for i, fire_time in enumerate(self._next_fire_times):
if fire_time == earliest_time:
self._next_fire_times[i] = self.triggers[i].next()
# Get next valid fire time and affected triggers
try:
fire_time, trigger_indices = self._get_next_valid_fire_time()
except RecursionError:
raise MaxIterationsReached

if fire_time is not None:
# Update last fire time and get next fire times for triggered sources
self._last_fire_time = fire_time
for i in trigger_indices:
self._next_fire_times[i] = self.triggers[i].next()

return earliest_time
return fire_time

def __setstate__(self, state: dict[str, Any]) -> None:
require_state_version(self, state, 1)
super().__setstate__(state)
self.cooldown_period = state["cooldown_period"]
self._last_fire_time = state["last_fire_time"]
self.max_iterations = state["max_iterations"]

def __getstate__(self) -> dict[str, Any]:
state = super().__getstate__()
state["cooldown_period"] = self.cooldown_period
state["last_fire_time"] = self._last_fire_time
state["max_iterations"] = self.max_iterations
return state

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.triggers})"
return (
f"{self.__class__.__name__}"
f"({self.triggers}, cooldown_period={self.cooldown_period.total_seconds()}"
f", max_iterations={self.max_iterations})"
)
49 changes: 47 additions & 2 deletions tests/triggers/test_combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,56 @@ def test_two_interval_triggers(self, timezone, serializer):
# The end time of the 6 second interval has been reached
assert trigger.next() is None

def test_cooldown_period(self, timezone, serializer):
start_time = datetime(2020, 5, 16, 14, 17, 30, 254212, tzinfo=timezone)
trigger = OrTrigger(
[
IntervalTrigger(seconds=4, start_time=start_time),
IntervalTrigger(seconds=6, start_time=start_time),
],
cooldown_period=1,
)
if serializer:
trigger = serializer.deserialize(serializer.serialize(trigger))

assert trigger.next() == start_time
assert trigger.next() == start_time + timedelta(seconds=4)
assert trigger.next() == start_time + timedelta(seconds=6)
assert trigger.next() == start_time + timedelta(seconds=8)
assert trigger.next() == start_time + timedelta(
seconds=12
) # No second trigger -> cooldown
assert trigger.next() == start_time + timedelta(seconds=16)
assert trigger.next() == start_time + timedelta(seconds=18)

def test_max_iterations(self, timezone, serializer):
start_time = datetime(2020, 5, 16, 14, 17, 30, 254212, tzinfo=timezone)
trigger = OrTrigger(
[
IntervalTrigger(seconds=1, start_time=start_time),
IntervalTrigger(seconds=1, start_time=start_time),
],
cooldown_period=100,
# Max iterations should be reached before the cooldown period
max_iterations=10,
)
if serializer:
trigger = serializer.deserialize(serializer.serialize(trigger))

# The triggers will keep firing after each other indefinitely
assert trigger.next() == start_time
pytest.raises(MaxIterationsReached, trigger.next)

def test_repr(self, timezone):
date1 = datetime(2020, 5, 16, 14, 17, 30, 254212, tzinfo=timezone)
date2 = datetime(2020, 5, 18, 15, 1, 53, 940564, tzinfo=timezone)
trigger = OrTrigger([DateTrigger(date1), DateTrigger(date2)])
trigger = OrTrigger(
[DateTrigger(date1), DateTrigger(date2)],
cooldown_period=1,
max_iterations=10000,
)
print(repr(trigger))
assert repr(trigger) == (
"OrTrigger([DateTrigger('2020-05-16 14:17:30.254212+02:00'), "
"DateTrigger('2020-05-18 15:01:53.940564+02:00')])"
"DateTrigger('2020-05-18 15:01:53.940564+02:00')], cooldown_period=1.0, max_iterations=10000)"
)
Loading