Skip to content

Commit

Permalink
fix: switch to generator for reset/retry endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
dreulavelle committed Oct 9, 2024
1 parent 24904fc commit bf4fc0e
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 82 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ services:
condition: service_healthy

riven_postgres:
image: postgres:16.3-alpine3.20
image: postgres:17.0-alpine3.20
container_name: riven-db
environment:
PGDATA: /var/lib/postgresql/data/pgdata
Expand Down
35 changes: 16 additions & 19 deletions src/controllers/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,44 +201,41 @@ async def get_items_by_imdb_ids(request: Request, imdb_ids: str):
return {"success": True, "items": [item.to_extended_dict() for item in items]}

@router.post(
"/reset",
summary="Reset Media Items",
description="Reset media items with bases on item IDs",
"/reset",
summary="Reset Media Items",
description="Reset media items with bases on item IDs",
)
async def reset_items(
request: Request, ids: str
):
async def reset_items(request: Request, ids: str):
ids = handle_ids(ids)
try:
media_items = get_media_items_by_ids(ids)
if not media_items or len(media_items) != len(ids):
raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
for media_item in media_items:
media_items_generator = get_media_items_by_ids(ids)
for media_item in media_items_generator:
try:
request.app.program.em.cancel_job(media_item)
clear_streams(media_item)
reset_media_item(media_item)
except Exception as e:
except ValueError as e:
logger.error(f"Failed to reset item with id {media_item._id}: {str(e)}")
continue
except Exception as e:
logger.error(f"Unexpected error while resetting item with id {media_item._id}: {str(e)}")
continue
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"success": True, "message": f"Reset items with id {ids}"}

@router.post(
"/retry",
summary="Retry Media Items",
description="Retry media items with bases on item IDs",
"/retry",
summary="Retry Media Items",
description="Retry media items with bases on item IDs",
)
async def retry_items(request: Request, ids: str):
ids = handle_ids(ids)
try:
media_items = get_media_items_by_ids(ids)
if not media_items or len(media_items) != len(ids):
raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
for media_item in media_items:
media_items_generator = get_media_items_by_ids(ids)
for media_item in media_items_generator:
request.app.program.em.cancel_job(media_item)
await asyncio.sleep(0.1) # Ensure cancellation is processed
await asyncio.sleep(0.1) # Ensure cancellation is processed
request.app.program.em.add_item(media_item)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
Expand Down
17 changes: 15 additions & 2 deletions src/program/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,31 @@
# cursor.execute("SET statement_timeout = 300000")
# cursor.close()

db = SQLAlchemy(settings_manager.settings.database.host, engine_options=engine_options)
db_host = settings_manager.settings.database.host
db = SQLAlchemy(db_host, engine_options=engine_options)

script_location = data_dir_path / "alembic/"


if not os.path.exists(script_location):
os.makedirs(script_location)

alembic = Alembic(db, script_location)
alembic.init(script_location)


def create_database_if_not_exists():
"""Create the database if it doesn't exist."""
db_name = db_host.split('/')[-1]
db_base_host = '/'.join(db_host.split('/')[:-1])
try:
temp_db = SQLAlchemy(db_base_host, engine_options=engine_options)
with temp_db.engine.connect() as connection:
connection.execution_options(isolation_level="AUTOCOMMIT").execute(text(f"CREATE DATABASE {db_name}"))
return True
except Exception as e:
logger.error(f"Failed to create database {db_name}: {e}")
return False

# https://stackoverflow.com/questions/61374525/how-do-i-check-if-alembic-migrations-need-to-be-generated
def need_upgrade_check() -> bool:
"""Check if there are any pending migrations."""
Expand Down
88 changes: 41 additions & 47 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import alembic
from sqlalchemy import delete, func, insert, select, text, union_all
from sqlalchemy.orm import Session, aliased, joinedload
from sqlalchemy.orm import Session, aliased, selectinload

from program.libraries.symlink import fix_broken_symlinks
from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation
Expand All @@ -21,41 +21,42 @@
def get_media_items_by_ids(media_item_ids: list[int]):
"""Retrieve multiple MediaItems by a list of MediaItem _ids using the _get_item_from_db method."""
from program.media.item import Episode, MediaItem, Movie, Season, Show
items = []

def get_item(session, media_item_id, item_type):
match item_type:
case "movie":
return session.execute(
select(Movie)
.where(MediaItem._id == media_item_id)
).unique().scalar_one()
case "show":
return session.execute(
select(Show)
.where(MediaItem._id == media_item_id)
.options(selectinload(Show.seasons).selectinload(Season.episodes))
).unique().scalar_one()
case "season":
return session.execute(
select(Season)
.where(Season._id == media_item_id)
.options(selectinload(Season.episodes))
).unique().scalar_one()
case "episode":
return session.execute(
select(Episode)
.where(Episode._id == media_item_id)
).unique().scalar_one()
case _:
return None

with db.Session() as session:
for media_item_id in media_item_ids:
item_type = session.execute(select(MediaItem.type).where(MediaItem._id==media_item_id)).scalar_one()
item_type = session.execute(select(MediaItem.type).where(MediaItem._id == media_item_id)).scalar_one()
if not item_type:
continue
item = None
match item_type:
case "movie":
item = session.execute(
select(Movie)
.where(MediaItem._id == media_item_id)
).unique().scalar_one()
case "show":
item = session.execute(
select(Show)
.where(MediaItem._id == media_item_id)
.options(joinedload(Show.seasons).joinedload(Season.episodes))
).unique().scalar_one()
case "season":
item = session.execute(
select(Season)
.where(Season._id == media_item_id)
.options(joinedload(Season.episodes))
).unique().scalar_one()
case "episode":
item = session.execute(
select(Episode)
.where(Episode._id == media_item_id)
).unique().scalar_one()
item = get_item(session, media_item_id, item_type)
if item:
items.append(item)

return items
yield item

def get_parent_items_by_ids(media_item_ids: list[int]):
"""Retrieve multiple MediaItems of type 'movie' or 'show' by a list of MediaItem _ids."""
Expand Down Expand Up @@ -312,36 +313,29 @@ def _get_item_from_db(session, item: "MediaItem"):
if not _ensure_item_exists_in_db(item):
return None
session.expire_on_commit = False
type = _get_item_type_from_db(item)
match type:
match item.type:
case "movie":
r = session.execute(
return session.execute(
select(Movie)
.where(MediaItem.imdb_id == item.imdb_id)
).unique().scalar_one()
return r
).unique().scalar_one_or_none()
case "show":
r = session.execute(
return session.execute(
select(Show)
.where(MediaItem.imdb_id == item.imdb_id)
.options(joinedload(Show.seasons).joinedload(Season.episodes))
).unique().scalar_one()
return r
).unique().scalar_one_or_none()
case "season":
r = session.execute(
return session.execute(
select(Season)
.where(Season._id == item._id)
.options(joinedload(Season.episodes))
).unique().scalar_one()
return r
).unique().scalar_one_or_none()
case "episode":
r = session.execute(
return session.execute(
select(Episode)
.where(Episode._id == item._id)
).unique().scalar_one()
return r
).unique().scalar_one_or_none()
case _:
logger.error(f"_get_item_from_db Failed to create item from type: {type}")
logger.error(f"_get_item_from_db Failed to create item from type: {item.type}")
return None

def _check_for_and_run_insertion_required(session, item: "MediaItem") -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/program/media/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class MediaItem(db.Model):
scraped_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
scraped_times: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, default=0)
active_stream: Mapped[Optional[dict]] = mapped_column(sqlalchemy.JSON, nullable=True)
streams: Mapped[list[Stream]] = relationship(secondary="StreamRelation", back_populates="parents", lazy="select", cascade="all")
blacklisted_streams: Mapped[list[Stream]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_parents", lazy="select", cascade="all")
streams: Mapped[list[Stream]] = relationship(secondary="StreamRelation", back_populates="parents", lazy="selectin", cascade="all")
blacklisted_streams: Mapped[list[Stream]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_parents", lazy="selectin", cascade="all")
symlinked: Mapped[Optional[bool]] = mapped_column(sqlalchemy.Boolean, default=False)
symlinked_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
symlinked_times: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, default=0)
Expand All @@ -59,7 +59,7 @@ class MediaItem(db.Model):
update_folder: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
overseerr_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)
last_state: Mapped[Optional[States]] = mapped_column(sqlalchemy.Enum(States), default=States.Unknown)
subtitles: Mapped[list[Subtitle]] = relationship(Subtitle, back_populates="parent", lazy="joined", cascade="all, delete-orphan")
subtitles: Mapped[list[Subtitle]] = relationship(Subtitle, back_populates="parent", lazy="selectin", cascade="all, delete-orphan")

__mapper_args__ = {
"polymorphic_identity": "mediaitem",
Expand Down
4 changes: 2 additions & 2 deletions src/program/media/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class Stream(db.Model):
rank: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=False)
lev_ratio: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=False)

parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamRelation", back_populates="streams")
blacklisted_parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_streams")
parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamRelation", back_populates="streams", lazy="selectin")
blacklisted_parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_streams", lazy="selectin")

__table_args__ = (
Index('ix_stream_infohash', 'infohash'),
Expand Down
8 changes: 6 additions & 2 deletions src/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from sqlalchemy import func, select, text

import program.db.db_functions as DB
from program.db.db import db, run_migrations, vacuum_and_analyze_index_maintenance
from program.db.db import create_database_if_not_exists, db, run_migrations, vacuum_and_analyze_index_maintenance


class Program(threading.Thread):
Expand Down Expand Up @@ -136,7 +136,11 @@ def start(self):

if not self.validate_database():
# We should really make this configurable via frontend...
return
logger.log("PROGRAM", "Database not found, trying to create database")
if not create_database_if_not_exists():
logger.error("Failed to create database, exiting")
return
logger.success("Database created successfully")

run_migrations()
self._init_db_from_symlinks()
Expand Down
11 changes: 5 additions & 6 deletions src/utils/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from loguru import logger
from sqlalchemy.orm.exc import StaleDataError
from subliminal import Episode, Movie
from concurrent.futures import CancelledError

import utils.websockets.manager as ws_manager
from program.db.db import db
Expand Down Expand Up @@ -69,12 +69,11 @@ def _process_future(self, future, service):
if item:
self.remove_item_from_running(item)
self.add_event(Event(emitted_by=service, item=item, run_at=timestamp))
except concurrent.futures.CancelledError:
# This is expected behavior when cancelling tasks
return
except StaleDataError:
# This is expected behavior when cancelling tasks
except (StaleDataError, CancelledError):
# Expected behavior when cancelling tasks or when the item was removed
return
except ValueError as e:
logger.error(f"Error in future for {future}: {e}")
except Exception as e:
logger.error(f"Error in future for {future}: {e}")
logger.exception(traceback.format_exc())
Expand Down

0 comments on commit bf4fc0e

Please sign in to comment.