fix: future cancellation resulted in reset/retry endpoints failing #817

Merged · 2 commits · Oct 26, 2024
36 changes: 12 additions & 24 deletions src/program/db/db_functions.py
@@ -1,5 +1,6 @@
 import os
 import shutil
+from threading import Event
 from typing import TYPE_CHECKING

 import alembic
@@ -170,15 +171,9 @@ def reset_media_item(item: "MediaItem"):
         item.reset()
         session.commit()

-def reset_streams(item: "MediaItem", active_stream_hash: str = None):
+def reset_streams(item: "MediaItem"):
     """Reset streams associated with a MediaItem."""
     with db.Session() as session:
-        item.store_state()
         item = session.merge(item)
-        if active_stream_hash:
-            stream = session.query(Stream).filter(Stream.infohash == active_stream_hash).first()
-            if stream:
-                blacklist_stream(item, stream, session)
-
         session.execute(
             delete(StreamRelation).where(StreamRelation.parent_id == item._id)
@@ -187,20 +182,11 @@ def reset_streams(item: "MediaItem", active_stream_hash: str = None):
         session.execute(
             delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id)
         )
-        item.active_stream = {}
         session.commit()

 def clear_streams(item: "MediaItem"):
     """Clear all streams for a media item."""
-    with db.Session() as session:
-        item = session.merge(item)
-        session.execute(
-            delete(StreamRelation).where(StreamRelation.parent_id == item._id)
-        )
-        session.execute(
-            delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id)
-        )
-        session.commit()
+    reset_streams(item)

 def clear_streams_by_id(media_item_id: int):
     """Clear all streams for a media item by the MediaItem _id."""
@@ -357,7 +343,7 @@ def store_item(item: "MediaItem"):
     finally:
         session.close()

-def run_thread_with_db_item(fn, service, program, input_id: int = None):
+def run_thread_with_db_item(fn, service, program, input_id, cancellation_event: Event):
     from program.media.item import MediaItem
     if input_id:
         with db.Session() as session:
@@ -377,11 +363,12 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None):
logger.log("PROGRAM", f"Service {service.__name__} emitted {item} from input item {input_item} of type {type(item).__name__}, backing off.")
program.em.remove_id_from_queues(input_item._id)

input_item.store_state()
session.commit()
if not cancellation_event.is_set():
input_item.store_state()
session.commit()

session.expunge_all()
yield res
return res
else:
# Indexing returns a copy of the item, was too lazy to create a copy attr func so this will do for now
indexed_item = next(fn(input_item), None)
@@ -392,9 +379,10 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None):
                 indexed_item.store_state()
                 session.delete(input_item)
                 indexed_item = session.merge(indexed_item)
-                session.commit()
-                logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...")
-                yield indexed_item._id
+                if not cancellation_event.is_set():
+                    session.commit()
+                    logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...")
+                return indexed_item._id
             return
         else:
             # Content services
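The shape of this change is easier to see outside the diff: the worker now receives a shared threading.Event, guards its database side effects behind it, and returns its result instead of yielding it. Below is a minimal runnable sketch of that pattern; run_with_cancellation and its print-as-commit step are hypothetical stand-ins, not Riven's actual code.

```python
from concurrent.futures import ThreadPoolExecutor
from threading import Event

def run_with_cancellation(fn, value, cancellation_event: Event):
    """Hypothetical stand-in for run_thread_with_db_item."""
    result = fn(value)
    if not cancellation_event.is_set():
        # In the PR this guard wraps store_state()/session.commit(),
        # so a cancelled run leaves no half-written state behind.
        print(f"committing {result!r}")
    return result  # returned, not yielded: Future.result() sees the value

executor = ThreadPoolExecutor(max_workers=1)
cancel = Event()
future = executor.submit(run_with_cancellation, lambda x: x * 2, 21, cancel)
print(future.result())  # prints "committing 42", then 42
```

The switch from yield to return also matters on its own: calling a generator function only creates the generator object, so the old submitted Future completed almost immediately and the real work happened later, when _process_future called next() on the result. With a plain return, the work runs inside the executor thread before the Future resolves, which is likely part of what made the old cancellation handling brittle.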
51 changes: 25 additions & 26 deletions src/program/media/item.py
@@ -132,8 +132,8 @@ def __init__(self, item: dict | None) -> None:
         #Post processing
         self.subtitles = item.get("subtitles", [])

-    def store_state(self) -> None:
-        new_state = self._determine_state()
+    def store_state(self, given_state=None) -> None:
+        new_state = given_state if given_state else self._determine_state()
         if self.last_state and self.last_state != new_state:
             sse_manager.publish_event("item_update", {"last_state": self.last_state, "new_state": new_state, "item_id": self._id})
         self.last_state = new_state
@@ -145,6 +145,10 @@ def is_stream_blacklisted(self, stream: Stream):
             session.refresh(self, attribute_names=['blacklisted_streams'])
             return stream in self.blacklisted_streams

+    def blacklist_active_stream(self):
+        stream = next(stream for stream in self.streams if stream.infohash == self.active_stream["infohash"])
+        self.blacklist_stream(stream)
+
     def blacklist_stream(self, stream: Stream):
         value = blacklist_stream(self, stream)
         if value:
@@ -321,20 +325,23 @@ def get_aliases(self) -> dict:
     def __hash__(self):
        return hash(self._id)

-    def reset(self, soft_reset: bool = False):
+    def reset(self):
         """Reset item attributes."""
         if self.type == "show":
             for season in self.seasons:
                 for episode in season.episodes:
-                    episode._reset(soft_reset)
-                season._reset(soft_reset)
+                    episode._reset()
+                season._reset()
         elif self.type == "season":
             for episode in self.episodes:
-                episode._reset(soft_reset)
-        self._reset(soft_reset)
-        self.store_state()
+                episode._reset()
+        self._reset()
+        if self.title:
+            self.store_state(States.Indexed)
+        else:
+            self.store_state(States.Requested)

-    def _reset(self, soft_reset):
+    def _reset(self):
         """Reset item attributes for rescraping."""
         if self.symlink_path:
             if Path(self.symlink_path).exists():
@@ -351,16 +358,8 @@ def _reset(self, soft_reset):
self.set("folder", None)
self.set("alternative_folder", None)

if not self.active_stream:
self.active_stream = {}
if not soft_reset:
if self.active_stream.get("infohash", False):
reset_streams(self, self.active_stream["infohash"])
else:
if self.active_stream.get("infohash", False):
stream = next((stream for stream in self.streams if stream.infohash == self.active_stream["infohash"]), None)
if stream:
self.blacklist_stream(stream)
reset_streams(self)
self.active_stream = {}

self.set("active_stream", {})
self.set("symlinked", False)
@@ -371,7 +370,7 @@ def _reset(self, soft_reset):
self.set("symlinked_times", 0)
self.set("scraped_times", 0)

logger.debug(f"Item {self.log_string} reset for rescraping")
logger.debug(f"Item {self.log_string} has been reset")

@property
def log_string(self):
@@ -456,10 +455,10 @@ def _determine_state(self):
             return States.Requested
         return States.Unknown

-    def store_state(self) -> None:
+    def store_state(self, given_state: States =None) -> None:
         for season in self.seasons:
-            season.store_state()
-        super().store_state()
+            season.store_state(given_state)
+        super().store_state(given_state)

     def __repr__(self):
         return f"Show:{self.log_string}:{self.state.name}"
@@ -527,10 +526,10 @@ class Season(MediaItem):
"polymorphic_load": "inline",
}

def store_state(self) -> None:
def store_state(self, given_state: States = None) -> None:
for episode in self.episodes:
episode.store_state()
super().store_state()
episode.store_state(given_state)
super().store_state(given_state)

def __init__(self, item):
self.type = "season"
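The new given_state parameter lets reset() pin an item's state explicitly instead of re-deriving it: an item that already has a title drops back to Indexed, anything else to Requested. A self-contained sketch of the override pattern, with the States values and Item class simplified stand-ins for the real ones:

```python
from enum import Enum

class States(Enum):
    Requested = "Requested"
    Indexed = "Indexed"

class Item:
    def __init__(self, title=None):
        self.title = title
        self.last_state = None

    def _determine_state(self):
        # Simplified: the real method inspects many more attributes.
        return States.Indexed if self.title else States.Requested

    def store_state(self, given_state=None):
        # Callers may force a state; otherwise it is derived as before.
        self.last_state = given_state if given_state else self._determine_state()

    def reset(self):
        self.store_state(States.Indexed if self.title else States.Requested)

item = Item(title="Example Show")
item.reset()
print(item.last_state)  # States.Indexed: no need to re-index after a reset
```

Forcing the state here avoids re-running state detection against half-cleared attributes in the middle of a reset.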
3 changes: 2 additions & 1 deletion src/program/symlink.py
@@ -94,7 +94,8 @@ def run(self, item: Union[Movie, Show, Season, Episode]):
         if not self._should_submit(items):
             if item.symlinked_times == 5:
                 logger.debug(f"Soft resetting {item.log_string} because required files were not found")
-                item.reset(True)
+                item.blacklist_active_stream()
+                item.reset()
                 yield item
             next_attempt = self._calculate_next_attempt(item)
             logger.debug(f"Waiting for {item.log_string} to become available, next attempt in {round((next_attempt - datetime.now()).total_seconds())} seconds")
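Where the old soft reset (item.reset(True)) blacklisted the active stream and cleared it in one call, the symlinker now performs the two steps explicitly, and the order matters: blacklist_active_stream() reads active_stream["infohash"] before reset() wipes it. A toy illustration of that ordering constraint, with FakeItem a hypothetical stand-in for the MediaItem API:

```python
class FakeItem:
    """Hypothetical stand-in showing why blacklist must precede reset."""
    def __init__(self):
        self.active_stream = {"infohash": "abc123"}
        self.blacklisted = []

    def blacklist_active_stream(self):
        # Reads the infohash while active_stream is still populated.
        self.blacklisted.append(self.active_stream["infohash"])

    def reset(self):
        self.active_stream = {}

item = FakeItem()
item.blacklist_active_stream()
item.reset()
assert item.blacklisted == ["abc123"] and item.active_stream == {}
```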
2 changes: 1 addition & 1 deletion src/routers/secure/items.py
@@ -531,7 +531,7 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str) -> SetTorrentRDRe
 #     downloader = request.app.program.services.get(Downloader).service
 #     with db.Session() as session:
 #         item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one()
-#         item.reset(True)
+#         item.reset()
 #         downloader.download_cached(item, hash)
 #         request.app.program.add_to_queue(item)
 #         return {"success": True, "message": f"Downloading {item.title} with hash {hash}"}
25 changes: 13 additions & 12 deletions src/utils/event_manager.py
@@ -1,4 +1,5 @@
 import os
+import threading
 import traceback

 from datetime import datetime
@@ -8,8 +9,7 @@

 from loguru import logger
 from pydantic import BaseModel
-from sqlalchemy.orm.exc import StaleDataError
-from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor

 from utils.sse_manager import sse_manager
 from program.db.db import db
@@ -37,6 +37,7 @@ def __init__(self):
         self._futures: list[Future] = []
         self._queued_events: list[Event] = []
         self._running_events: list[Event] = []
+        self._canceled_futures: list[Future] = []
         self.mutex = Lock()

     def _find_or_create_executor(self, service_cls) -> ThreadPoolExecutor:
@@ -71,7 +72,7 @@ def _process_future(self, future, service):
             service (type): The service class associated with the future.
         """
         try:
-            result = next(future.result(), None)
+            result = future.result()
             if future in self._futures:
                 self._futures.remove(future)
             sse_manager.publish_event("event_update", self.get_event_updates())
@@ -81,10 +82,10 @@
                 item_id, timestamp = result, datetime.now()
             if item_id:
                 self.remove_event_from_running(item_id)
+                if future.cancellation_event.is_set():
+                    logger.debug(f"Future with Item ID: {item_id} was cancelled discarding results...")
+                    return
                 self.add_event(Event(emitted_by=service, item_id=item_id, run_at=timestamp))
-        except (StaleDataError, CancelledError):
-            # Expected behavior when cancelling tasks or when the item was removed
-            return
         except Exception as e:
             logger.error(f"Error in future for {future}: {e}")
             logger.exception(traceback.format_exc())
@@ -166,8 +167,10 @@ def submit_job(self, service, program, event=None):
             log_message += f" with Item ID: {item_id}"
         logger.debug(log_message)

+        cancellation_event = threading.Event()
         executor = self._find_or_create_executor(service)
-        future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id)
+        future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id, cancellation_event)
+        future.cancellation_event = cancellation_event
         if event:
             future.event = event
         self._futures.append(future)
@@ -186,27 +189,25 @@ def cancel_job(self, item_id: int, suppress_logs=False):
             item_id, related_ids = get_item_ids(session, item_id)
             ids_to_cancel = set([item_id] + related_ids)

-            futures_to_remove = []
             for future in self._futures:
                 future_item_id = None
                 future_related_ids = []

-                if hasattr(future, 'event') and hasattr(future.event, 'item'):
+                if hasattr(future, 'event') and hasattr(future.event, 'item_id'):
                     future_item = future.event.item_id
                     future_item_id, future_related_ids = get_item_ids(session, future_item)

                 if future_item_id in ids_to_cancel or any(rid in ids_to_cancel for rid in future_related_ids):
                     self.remove_id_from_queues(future_item)
-                    futures_to_remove.append(future)
                     if not future.done() and not future.cancelled():
                         try:
+                            future.cancellation_event.set()
                             future.cancel()
+                            self._canceled_futures.append(future)
                         except Exception as e:
                             if not suppress_logs:
                                 logger.error(f"Error cancelling future for {future_item.log_string}: {str(e)}")

-        for future in futures_to_remove:
-            self._futures.remove(future)
-
         logger.debug(f"Canceled jobs for Item ID {item_id} and its children.")
