diff --git a/.gitignore b/.gitignore
index 32d7c46..fcaf227 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@
 *.mp3
 *.m4a
 *.yaml
+*.db
 !.github/**/*.yaml
 
 archive/
diff --git a/podcast_archiver/cli.py b/podcast_archiver/cli.py
index 0c4531e..7557ccd 100644
--- a/podcast_archiver/cli.py
+++ b/podcast_archiver/cli.py
@@ -11,7 +11,7 @@
 from podcast_archiver import __version__ as version
 from podcast_archiver import constants
 from podcast_archiver.base import PodcastArchiver
-from podcast_archiver.config import Settings
+from podcast_archiver.config import Settings, in_ci
 from podcast_archiver.console import console
 from podcast_archiver.exceptions import InvalidSettings
 from podcast_archiver.logging import configure_logging
@@ -32,6 +32,7 @@
                 "--opml",
                 "--dir",
                 "--config",
+                "--database",
             ],
         },
         {
@@ -47,6 +48,7 @@
             "options": [
                 "--update",
                 "--max-episodes",
+                "--ignore-database",
             ],
         },
     ]
@@ -155,7 +157,6 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b
         resolve_path=True,
         path_type=pathlib.Path,
     ),
-    show_default=True,
     required=False,
     default=pathlib.Path("."),
     show_envvar=True,
@@ -232,6 +233,7 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b
     "maximum_episode_count",
     type=int,
    default=0,
+    show_envvar=True,
     help=Settings.model_fields["maximum_episode_count"].description,
 )
 @click.version_option(
@@ -252,15 +254,33 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b
 @click.option(
     "-c",
     "--config",
-    "config_path",
+    "config",
     type=ConfigFile(),
-    default=get_default_config_path,
-    show_default=False,
+    default=get_default_config_path(),
+    show_default=not in_ci(),
     is_eager=True,
     envvar=constants.ENVVAR_PREFIX + "_CONFIG",
     show_envvar=True,
     help="Path to a config file. Command line arguments will take precedence.",
 )
+@click.option(
+    "--database",
+    type=click.Path(
+        exists=False,
+        dir_okay=False,
+        resolve_path=True,
+    ),
+    default=None,
+    show_envvar=True,
+    help=Settings.model_fields["database"].description,
+)
+@click.option(
+    "--ignore-database",
+    type=bool,
+    is_flag=True,
+    show_envvar=True,
+    help=Settings.model_fields["ignore_database"].description,
+)
 @click.pass_context
 def main(ctx: click.RichContext, /, **kwargs: Any) -> int:
     configure_logging(kwargs["verbose"])
@@ -278,7 +298,7 @@ def main(ctx: click.RichContext, /, **kwargs: Any) -> int:
         pa.run()
     except InvalidSettings as exc:
         raise click.BadParameter(f"Invalid settings: {exc}") from exc
-    except KeyboardInterrupt as exc:
+    except KeyboardInterrupt as exc:  # pragma: no cover
         raise click.Abort("Interrupted by user") from exc
     except FileNotFoundError as exc:
         raise click.Abort(exc) from exc
diff --git a/podcast_archiver/config.py b/podcast_archiver/config.py
index 3336632..af0f3a9 100644
--- a/podcast_archiver/config.py
+++ b/podcast_archiver/config.py
@@ -3,11 +3,19 @@
 import pathlib
 import textwrap
 from datetime import datetime
-from functools import cached_property
+from os import getenv
 from typing import IO, Any, Text
 
 import pydantic
-from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, DirectoryPath, Field, FilePath
+from pydantic import (
+    AnyHttpUrl,
+    BaseModel,
+    BeforeValidator,
+    DirectoryPath,
+    Field,
+    FilePath,
+    NewPath,
+)
 from pydantic import ConfigDict as _ConfigDict
 from pydantic_core import to_json
 from typing_extensions import Annotated
@@ -15,10 +23,9 @@
 
 from podcast_archiver import __version__ as version
 from podcast_archiver import constants
-from podcast_archiver.console import console
+from podcast_archiver.database import BaseDatabase, Database, DummyDatabase
 from podcast_archiver.exceptions import InvalidSettings
 from podcast_archiver.models import ALL_FIELD_TITLES_STR
-from podcast_archiver.utils import FilenameFormatter
 
 
 def expanduser(v: pathlib.Path) -> pathlib.Path:
@@ -29,6 +36,12 @@ def expanduser(v: pathlib.Path) -> pathlib.Path:
 
 UserExpandedDir = Annotated[DirectoryPath, BeforeValidator(expanduser)]
 UserExpandedFile = Annotated[FilePath, BeforeValidator(expanduser)]
+UserExpandedPossibleFile = Annotated[FilePath | NewPath, BeforeValidator(expanduser)]
+
+
+def in_ci() -> bool:
+    val = getenv("CI", "").lower()
+    return val in ("true", "1")
 
 
 class Settings(BaseModel):
@@ -108,7 +121,22 @@
         description=f"Download only the first {constants.DEBUG_PARTIAL_SIZE} bytes of episodes for debugging purposes.",
     )
 
-    config_path: FilePath | None = Field(
+    database: UserExpandedPossibleFile | None = Field(
+        default=None,
+        description=(
+            "Location of the database to keep track of downloaded episodes. By default, the database will be created "
+            f"as '{constants.DEFAULT_DATABASE_FILENAME}' in the directory of the config file."
+        ),
+    )
+    ignore_database: bool = Field(
+        default=False,
+        description=(
+            "Ignore the episodes database when downloading. This will cause files to be downloaded again, even if they "
+            "already exist in the database."
+        ),
+    )
+
+    config: FilePath | None = Field(
         default=None,
         exclude=True,
     )
@@ -133,7 +161,7 @@ def load_from_yaml(cls, path: pathlib.Path) -> Settings:
 
         if not isinstance(content, dict):
             raise InvalidSettings("Not a valid YAML document")
-        content.update(config_path=path)
+        content.update(config=path)
         return cls.load_from_dict(content)
 
     @classmethod
@@ -147,7 +175,7 @@ def generate_default_config(cls, file: IO[Text] | None = None) -> None:
         ]
 
         for name, field in cls.model_fields.items():
-            if name in ("config_path",):
+            if name in ("config",):
                 continue
             cli_opt = (
                 wrapper.wrap(f"Equivalent command line option: --{field.alias.replace('_', '-')}")
@@ -166,11 +194,22 @@ def generate_default_config(cls, file: IO[Text] | None = None) -> None:
 
         contents = "\n".join(lines).strip()
         if not file:
+            from podcast_archiver.console import console
+
             console.print(contents, highlight=False)
             return
         with file:
             file.write(contents + "\n")
 
-    @cached_property
-    def filename_formatter(self) -> FilenameFormatter:
-        return FilenameFormatter(self)
+    def get_database(self) -> BaseDatabase:
+        if getenv("TESTING", "0").lower() in ("1", "true"):
+            return DummyDatabase()
+
+        if self.database:
+            db_path = str(self.database)
+        elif self.config:
+            db_path = str(self.config.parent / constants.DEFAULT_DATABASE_FILENAME)
+        else:
+            db_path = constants.DEFAULT_DATABASE_FILENAME
+
+        return Database(filename=db_path, ignore_existing=self.ignore_database)
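The lookup order implemented by `Settings.get_database()` boils down to the following (a standalone sketch; `resolve_db_path` is a hypothetical helper name, and the real method returns `Database`/`DummyDatabase` instances rather than paths):

```python
from pathlib import Path

DEFAULT_DATABASE_FILENAME = "podcast-archiver.db"  # mirrors constants.py below


def resolve_db_path(database: Path | None, config: Path | None) -> str:
    if database:  # an explicit --database / settings value wins
        return str(database)
    if config:  # otherwise the database lives next to the config file
        return str(config.parent / DEFAULT_DATABASE_FILENAME)
    return DEFAULT_DATABASE_FILENAME  # last resort: the working directory


assert resolve_db_path(None, Path("/etc/pa/config.yaml")) == "/etc/pa/podcast-archiver.db"
```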
diff --git a/podcast_archiver/console.py b/podcast_archiver/console.py
index a9463af..cc38486 100644
--- a/podcast_archiver/console.py
+++ b/podcast_archiver/console.py
@@ -1,3 +1,104 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from rich import progress
 from rich.console import Console
 
+if TYPE_CHECKING:
+    from podcast_archiver.config import Settings
+    from podcast_archiver.models import Episode
+    from podcast_archiver.types import ProgressCallback
+
 console = Console()
+
+
+PROGRESS_COLUMNS = (
+    progress.SpinnerColumn(finished_text="[bar.finished]✔[/]"),
+    progress.TextColumn("[blue]{task.fields[date]:%Y-%m-%d}"),
+    progress.TextColumn("[progress.description]{task.description}"),
+    progress.BarColumn(bar_width=25),
+    progress.TaskProgressColumn(),
+    progress.TimeRemainingColumn(),
+    progress.DownloadColumn(),
+    progress.TransferSpeedColumn(),
+)
+
+
+def noop_callback(total: int | None = None, completed: int | None = None) -> None:
+    pass
+
+
+class ProgressDisplay:
+    disabled: bool
+
+    _progress: progress.Progress
+    _state: dict[Episode, progress.TaskID]
+
+    def __init__(self, settings: Settings) -> None:
+        self.disabled = settings.verbose > 1 or settings.quiet
+        self._progress = progress.Progress(
+            *PROGRESS_COLUMNS,
+            console=console,
+            disable=self.disabled,
+        )
+        self._progress.live.vertical_overflow = "visible"
+        self._state = {}
+
+    def _get_task_id(self, episode: Episode) -> progress.TaskID:
+        # Note: dict.get(key, default) evaluates the default eagerly, which would
+        # register a duplicate task on every lookup; check for a hit first.
+        if (task_id := self._state.get(episode)) is not None:
+            return task_id
+        return self.register(episode)
+
+    def __enter__(self) -> ProgressDisplay:
+        if not self.disabled:
+            self._progress.start()
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        if not self.disabled:
+            self._progress.stop()
+        self._state = {}
+
+    def shutdown(self) -> None:
+        for task in self._progress.tasks or []:
+            if not task.finished:
+                task.visible = False
+        self._progress.stop()
+
+    def register(self, episode: Episode) -> progress.TaskID:
+        task_id = self._progress.add_task(
+            description=episode.title,
+            date=episode.published_time,
+            total=episode.enclosure.length,
+            visible=False,
+        )
+        self._state[episode] = task_id
+        return task_id
+
+    def update(self, episode: Episode, visible: bool = True, **kwargs: Any) -> None:
+        if self.disabled:
+            return
+
+        task_id = self._get_task_id(episode)
+        self._progress.update(task_id, visible=visible, **kwargs)
+
+    def completed(self, episode: Episode, visible: bool = True, **kwargs: Any) -> None:
+        if self.disabled:
+            return
+
+        task_id = self._get_task_id(episode)
+        self._progress.update(task_id, visible=visible, completed=episode.enclosure.length, **kwargs)
+
+    def get_callback(self, episode: Episode) -> ProgressCallback:
+        if self.disabled:
+            return noop_callback
+
+        task_id = self._get_task_id(episode)
+
+        def _callback(total: int | None = None, completed: int | None = None) -> None:
+            self._progress.update(task_id, total=total, completed=completed, visible=True)
+
+        return _callback
diff --git a/podcast_archiver/constants.py b/podcast_archiver/constants.py
index 5a07b32..cf04cf1 100644
--- a/podcast_archiver/constants.py
+++ b/podcast_archiver/constants.py
@@ -18,3 +18,4 @@
 DEFAULT_ARCHIVE_DIRECTORY = pathlib.Path(".")
 DEFAULT_FILENAME_TEMPLATE = "{show.title}/{episode.published_time:%Y-%m-%d} - {episode.title}.{ext}"
 DEFAULT_CONCURRENCY = 4
+DEFAULT_DATABASE_FILENAME = "podcast-archiver.db"
diff --git a/podcast_archiver/database.py b/podcast_archiver/database.py
new file mode 100644
index 0000000..c971206
--- /dev/null
+++ b/podcast_archiver/database.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import sqlite3
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, Iterator
+
+from podcast_archiver.logging import logger
+
+if TYPE_CHECKING:
+    from podcast_archiver.models import Episode
+
+
+class BaseDatabase(ABC):
+    @abstractmethod
+    def add(self, episode: Episode) -> None:
+        pass  # pragma: no cover
+
+    @abstractmethod
+    def exists(self, episode: Episode) -> bool:
+        pass  # pragma: no cover
+
+
+class DummyDatabase(BaseDatabase):
+    def add(self, episode: Episode) -> None:
+        pass
+
+    def exists(self, episode: Episode) -> bool:
+        return False
+
+
+class Database(BaseDatabase):
+    filename: str
+    ignore_existing: bool
+
+    lock = Lock()
+
+    def __init__(self, filename: str, ignore_existing: bool) -> None:
+        self.filename = filename
+        self.ignore_existing = ignore_existing
+        self.migrate()
+
+    @contextmanager
+    def get_conn(self) -> Iterator[sqlite3.Connection]:
+        with self.lock, sqlite3.connect(self.filename) as conn:
+            yield conn
+
+    def migrate(self) -> None:
+        logger.debug(f"Migrating database at {self.filename}")
+        with self.get_conn() as conn:
+            conn.execute(
+                """\
+                CREATE TABLE IF NOT EXISTS episodes(
+                    guid TEXT UNIQUE NOT NULL,
+                    title TEXT
+                )"""
+            )
+
+    def add(self, episode: Episode) -> None:
+        with self.get_conn() as conn:
+            try:
+                conn.execute(
+                    "INSERT INTO episodes(guid, title) VALUES (?, ?)",
+                    (episode.guid, episode.title),
+                )
+            except sqlite3.IntegrityError:
+                logger.debug(f"Episode exists: {episode}")
+
+    def exists(self, episode: Episode) -> bool:
+        if self.ignore_existing:
+            return False
+        with self.get_conn() as conn:
+            result = conn.execute(
+                "SELECT EXISTS(SELECT 1 FROM episodes WHERE guid = ?)",
+                (episode.guid,),
+            )
+            return bool(result.fetchone()[0])
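In isolation, the new `Database` behaves like this (a usage sketch; `SimpleNamespace` stands in for a real `models.Episode`, since `add()` and `exists()` only touch `.guid` and `.title`):

```python
from types import SimpleNamespace

from podcast_archiver.database import Database

episode = SimpleNamespace(guid="urn:example:ep1", title="Episode 1")

db = Database(filename="episodes.db", ignore_existing=False)  # creates/migrates the SQLite file
assert not db.exists(episode)
db.add(episode)
db.add(episode)  # duplicate INSERTs are swallowed via sqlite3.IntegrityError
assert db.exists(episode)
```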
diff --git a/podcast_archiver/download.py b/podcast_archiver/download.py
index d4265d5..15a5f35 100644
--- a/podcast_archiver/download.py
+++ b/podcast_archiver/download.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
 from threading import Event
-from typing import IO, TYPE_CHECKING, Any
+from typing import IO, TYPE_CHECKING
 
 from podcast_archiver import constants
+from podcast_archiver.console import noop_callback
 from podcast_archiver.enums import DownloadResult
 from podcast_archiver.logging import logger
 from podcast_archiver.session import session
+from podcast_archiver.types import EpisodeResult, ProgressCallback
 from podcast_archiver.utils import atomic_write
 
 if TYPE_CHECKING:
@@ -37,40 +39,37 @@ def __init__(
         episode: Episode,
         *,
         target: Path,
-        feed_info: FeedInfo,
         debug_partial: bool = False,
         write_info_json: bool = False,
-        progress: rich_progress.Progress | None = None,
+        progress_callback: ProgressCallback = noop_callback,
         stop_event: Event | None = None,
     ) -> None:
         self.episode = episode
         self.target = target
-        self.feed_info = feed_info
         self._debug_partial = debug_partial
         self._write_info_json = write_info_json
-        self._progress = progress
+        self.progress_callback = progress_callback
         self.stop_event = stop_event or Event()
 
-        self.init_progress()
-
     def __repr__(self) -> str:
         return f"EpisodeDownload({self})"
 
     def __str__(self) -> str:
         return str(self.episode)
 
-    def __call__(self) -> DownloadResult:
+    def __call__(self) -> EpisodeResult:
         try:
             return self.run()
         except Exception as exc:
             logger.error("Download failed", exc_info=exc)
-            return DownloadResult.FAILED
+            return EpisodeResult(self.episode, DownloadResult.FAILED)
+
+    def run(self) -> EpisodeResult:
+        if self.target.exists():
+            return EpisodeResult(self.episode, DownloadResult.ALREADY_EXISTS)
 
-    def run(self) -> DownloadResult:
         self.target.parent.mkdir(parents=True, exist_ok=True)
         self.write_info_json()
 
-        if result := self.preflight_check():
-            return result
-
         response = session.get(
             self.episode.enclosure.url,
@@ -80,53 +79,27 @@ def run(self) -> DownloadResult:
         )
         response.raise_for_status()
         total_size = int(response.headers.get("content-length", "0"))
-        self.update_progress(total=total_size)
+        self.progress_callback(total=total_size)
 
         with atomic_write(self.target, mode="wb") as fp:
             receive_complete = self.receive_data(fp, response)
         if not receive_complete:
             self.target.unlink(missing_ok=True)
-            return DownloadResult.ABORTED
+            return EpisodeResult(self.episode, DownloadResult.ABORTED)
 
         logger.info("Completed download of %s", self.target)
-        return DownloadResult.COMPLETED_SUCCESSFULLY
+        return EpisodeResult(self.episode, DownloadResult.COMPLETED_SUCCESSFULLY)
 
     @property
     def infojsonfile(self) -> Path:
         return self.target.with_suffix(".info.json")
 
-    def init_progress(self) -> None:
-        if self._progress is None:
-            return
-
-        self._task_id = self._progress.add_task(
-            description=self.episode.title,
-            date=self.episode.published_time,
-            total=self.episode.enclosure.length,
-            visible=False,
-        )
-
-    def update_progress(self, visible: bool = True, **kwargs: Any) -> None:
-        if self._task_id is None:
-            return
-        assert self._progress
-        self._progress.update(self._task_id, visible=visible, **kwargs)
-
-    def preflight_check(self) -> DownloadResult | None:
-        if self.target_exists:
-            logger.debug("Pre-flight check on episode '%s': already exists.", self.episode.title)
-            size = self.target.stat().st_size
-            self.update_progress(total=size, completed=size)
-            return DownloadResult.ALREADY_EXISTS
-
-        return None
-
     def receive_data(self, fp: IO[str], response: Response) -> bool:
         total_written = 0
         for chunk in response.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE):
             total_written += fp.write(chunk)
-            self.update_progress(completed=total_written)
+            self.progress_callback(completed=total_written)
 
             if self._debug_partial and total_written >= constants.DEBUG_PARTIAL_SIZE:
                 logger.debug("Partial download completed.")
@@ -143,7 +116,3 @@ def write_info_json(self) -> None:
         logger.info("Writing episode metadata to %s", self.infojsonfile.name)
         with atomic_write(self.infojsonfile) as fp:
             fp.write(self.episode.model_dump_json(indent=2) + "\n")
-
-    @property
-    def target_exists(self) -> bool:
-        return self.target.exists()
diff --git a/podcast_archiver/models.py b/podcast_archiver/models.py
index ce7e62c..940e34d 100644
--- a/podcast_archiver/models.py
+++ b/podcast_archiver/models.py
@@ -73,7 +73,13 @@ class Episode(BaseModel):
     shownotes: str | None = Field(None, repr=False)
     content: list[Content] | None = Field(None, repr=False, alias="content", exclude=True)
 
-    _feed_info: FeedInfo
+    guid: str = Field(default=None, alias="id")  # type: ignore[assignment]
+
+    def __hash__(self) -> int:
+        return hash(self.guid)
+
+    def __eq__(self, other: Episode | Any) -> bool:
+        return isinstance(other, Episode) and self.guid == other.guid
 
     @field_validator("published_time", mode="before")
     @classmethod
@@ -108,6 +114,14 @@ def populate_enclosure(self) -> Episode:
         self.original_filename = Path(self.enclosure.href.path).name if self.enclosure.href.path else ""
         return self
 
+    @model_validator(mode="after")
+    def ensure_guid(self) -> Episode:
+        if not self.guid:
+            # If no GUID is given, use the enclosure url instead
+            # See https://help.apple.com/itc/podcasts_connect/#/itcb54353390
+            self.guid = self.enclosure.url
+        return self
+
     def _get_enclosure_url(self) -> Link:
         for link in self.links:
             if (
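The `guid` fallback above is the linchpin for both de-duplication paths (dict/set hashing and the database key). A minimal self-contained rendition of the pattern, with a hypothetical `EpisodeSketch` model standing in for `Episode`:

```python
from __future__ import annotations

from pydantic import BaseModel, model_validator


class EpisodeSketch(BaseModel):
    guid: str | None = None  # populated from the feed entry's <guid>/id when present
    enclosure_url: str

    @model_validator(mode="after")
    def ensure_guid(self) -> EpisodeSketch:
        if not self.guid:
            # No GUID in the feed entry: the enclosure URL is the next-best stable key
            self.guid = self.enclosure_url
        return self

    def __hash__(self) -> int:
        return hash(self.guid)


assert EpisodeSketch(enclosure_url="http://example.invalid/ep.mp3").guid == "http://example.invalid/ep.mp3"
```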
diff --git a/podcast_archiver/processor.py b/podcast_archiver/processor.py
index 1957fb7..4b10f59 100644
--- a/podcast_archiver/processor.py
+++ b/podcast_archiver/processor.py
@@ -1,34 +1,26 @@
 from __future__ import annotations
 
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from threading import Event
 from typing import TYPE_CHECKING
 
 from pydantic import AnyHttpUrl, ValidationError
 from requests import HTTPError
-from rich import progress as rich_progress
 
-from podcast_archiver.console import console
+from podcast_archiver.console import ProgressDisplay, console
 from podcast_archiver.download import DownloadJob
 from podcast_archiver.enums import DownloadResult, QueueCompletionType
 from podcast_archiver.logging import logger
-from podcast_archiver.models import Feed
+from podcast_archiver.models import Episode, Feed, FeedInfo
+from podcast_archiver.types import EpisodeResult, EpisodeResultsList
 from podcast_archiver.utils import FilenameFormatter
 
 if TYPE_CHECKING:
-    from podcast_archiver.config import Settings
+    from pathlib import Path
 
-PROGRESS_COLUMNS = (
-    rich_progress.SpinnerColumn(finished_text="[bar.finished]✔[/]"),
-    rich_progress.TextColumn("[blue]{task.fields[date]:%Y-%m-%d}"),
-    rich_progress.TextColumn("[progress.description]{task.description}"),
-    rich_progress.BarColumn(bar_width=25),
-    rich_progress.TaskProgressColumn(),
-    rich_progress.TimeRemainingColumn(),
-    rich_progress.DownloadColumn(),
-    rich_progress.TransferSpeedColumn(),
-)
+    from podcast_archiver.config import Settings
+    from podcast_archiver.database import BaseDatabase
 
 
 @dataclass
@@ -41,22 +33,19 @@ class ProcessingResult:
 
 class FeedProcessor:
     settings: Settings
+    database: BaseDatabase
     filename_formatter: FilenameFormatter
     pool_executor: ThreadPoolExecutor
-    progress: rich_progress.Progress
+    progress: ProgressDisplay
     stop_event: Event
 
     def __init__(self, settings: Settings) -> None:
         self.settings = settings
         self.filename_formatter = FilenameFormatter(settings)
+        self.database = settings.get_database()
         self.pool_executor = ThreadPoolExecutor(max_workers=self.settings.concurrency)
-        self.progress = rich_progress.Progress(
-            *PROGRESS_COLUMNS,
-            console=console,
-            disable=settings.verbose > 1 or settings.quiet,
-        )
-        # self.progress.live.vertical_overflow = "visible"
+        self.progress = ProgressDisplay(settings)
         self.stop_event = Event()
 
     def process(self, url: AnyHttpUrl) -> ProcessingResult:
@@ -77,54 +66,78 @@ def process(self, url: AnyHttpUrl) -> ProcessingResult:
         console.print(f"\n[bold bright_magenta]Downloading archive for: {feed.info.title}[/]\n")
 
         with self.progress:
-            futures, completion_msg = self._process_episodes(feed=feed)
-            self._handle_futures(futures, result=result)
+            episode_results, completion_msg = self._process_episodes(feed=feed)
+            self._handle_results(episode_results, result=result)
 
         console.print(f"\n[bar.finished]✔ {completion_msg}[/]")
         return result
 
-    def _process_episodes(self, feed: Feed) -> tuple[list[Future[DownloadResult]], QueueCompletionType]:
-        futures: list[Future[DownloadResult]] = []
+    def _preflight_check(self, episode: Episode, target: Path) -> DownloadResult | None:
+        if self.database.exists(episode):
+            logger.debug("Pre-flight check on episode '%s': already in database.", episode.title)
+            self.progress.completed(episode)
+            return DownloadResult.ALREADY_EXISTS
+
+        if target.exists():
+            logger.debug("Pre-flight check on episode '%s': already on disk.", episode.title)
+            self.progress.completed(episode)
+            return DownloadResult.ALREADY_EXISTS
+
+        return None
+
+    def _process_episodes(self, feed: Feed) -> tuple[EpisodeResultsList, QueueCompletionType]:
+        results: EpisodeResultsList = []
         for idx, episode in enumerate(feed.episode_iter(self.settings.maximum_episode_count), 1):
-            target = self.filename_formatter.format(episode=episode, feed_info=feed.info)
-            download_job = DownloadJob(
-                episode,
-                target=target,
-                feed_info=feed.info,
-                debug_partial=self.settings.debug_partial,
-                write_info_json=self.settings.write_info_json,
-                progress=self.progress,
-                stop_event=self.stop_event,
-            )
-            if self.settings.update_archive and download_job.target_exists:
-                logger.info("Up to date with %r", episode)
-                return futures, QueueCompletionType.FOUND_EXISTING
+            if completion := self._process_episode(episode, feed.info, results):
+                return results, completion
 
-            logger.info("Queueing download for %r", episode)
-            futures.append(self.pool_executor.submit(download_job))
             if (max_count := self.settings.maximum_episode_count) and idx == max_count:
                 logger.info("Reached requested maximum episode count of %s", max_count)
-                return futures, QueueCompletionType.MAX_EPISODES
+                return results, QueueCompletionType.MAX_EPISODES
 
-        return futures, QueueCompletionType.COMPLETED
+        return results, QueueCompletionType.COMPLETED
 
-    def _handle_futures(self, futures: list[Future[DownloadResult]], *, result: ProcessingResult) -> None:
-        for future in futures:
-            try:
-                _result = future.result()
-                logger.debug("Got future result %s", _result)
-            except Exception:
-                result.failures += 1
-            else:
-                result.success += 1
+    def _process_episode(
+        self, episode: Episode, feed_info: FeedInfo, results: EpisodeResultsList
+    ) -> QueueCompletionType | None:
+        target = self.filename_formatter.format(episode=episode, feed_info=feed_info)
+        if result := self._preflight_check(episode, target):
+            results.append(EpisodeResult(episode, result))
+            if self.settings.update_archive:
+                logger.info("Up to date with %r", episode)
+                return QueueCompletionType.FOUND_EXISTING
+            return None
+
+        logger.info("Queueing download for %r", episode)
+        results.append(
+            self.pool_executor.submit(
+                DownloadJob(
+                    episode,
+                    target=target,
+                    debug_partial=self.settings.debug_partial,
+                    write_info_json=self.settings.write_info_json,
+                    progress_callback=self.progress.get_callback(episode),
+                    stop_event=self.stop_event,
+                )
+            )
+        )
+        return None
+
+    def _handle_results(self, episode_results: EpisodeResultsList, *, result: ProcessingResult) -> None:
+        for episode_result in episode_results:
+            if not isinstance(episode_result, EpisodeResult):
+                try:
+                    episode_result = episode_result.result()
+                    logger.debug("Got future result %s", episode_result)
+                except Exception:
+                    result.failures += 1
+                    continue
+            self.database.add(episode_result.episode)
+            result.success += 1
 
     def shutdown(self) -> None:
         self.stop_event.set()
         self.pool_executor.shutdown(cancel_futures=True)
-
-        for task in self.progress.tasks or []:
-            if not task.finished:
-                task.visible = False
-        self.progress.stop()
+        self.progress.shutdown()
 
         logger.debug("Completed processor shutdown")
diff --git a/podcast_archiver/types.py b/podcast_archiver/types.py
new file mode 100644
index 0000000..0816f55
--- /dev/null
+++ b/podcast_archiver/types.py
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from concurrent.futures import Future
+from typing import TYPE_CHECKING, NamedTuple, Protocol, TypeAlias
+
+if TYPE_CHECKING:
+    from podcast_archiver.enums import DownloadResult
+    from podcast_archiver.models import Episode
+
+
+class EpisodeResult(NamedTuple):
+    episode: Episode
+    result: DownloadResult
+
+
+class ProgressCallback(Protocol):
+    def __call__(self, total: int | None = None, completed: int | None = None) -> None: ...
+
+
+EpisodeResultsList: TypeAlias = list[Future[EpisodeResult] | EpisodeResult]
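`EpisodeResultsList` deliberately mixes immediate results (pre-flight hits) with pending futures from the thread pool; consuming it looks roughly like `_handle_results` above. A sketch (the plain strings stand in for `Episode` objects, which `NamedTuple` does not enforce):

```python
from concurrent.futures import Future

from podcast_archiver.enums import DownloadResult
from podcast_archiver.types import EpisodeResult, EpisodeResultsList

pending: Future = Future()
pending.set_result(EpisodeResult("ep2", DownloadResult.COMPLETED_SUCCESSFULLY))

results: EpisodeResultsList = [
    EpisodeResult("ep1", DownloadResult.ALREADY_EXISTS),  # resolved synchronously at pre-flight
    pending,  # resolved by the ThreadPoolExecutor
]

for item in results:
    if not isinstance(item, EpisodeResult):
        item = item.result()  # blocks; re-raises if the download job raised
    print(item.episode, item.result)
```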
diff --git a/podcast_archiver/utils.py b/podcast_archiver/utils.py
index 34bdf85..cbb0a5e 100644
--- a/podcast_archiver/utils.py
+++ b/podcast_archiver/utils.py
@@ -98,7 +98,7 @@ def format_field(self, value: Any, format_spec: str) -> str:
             return slugify(formatted)
         return make_filename_safe(formatted)
 
-    def format(self, episode: Episode, feed_info: FeedInfo) -> Path:  # type: ignore[override] # noqa: A003
+    def format(self, episode: Episode, feed_info: FeedInfo) -> Path:  # type: ignore[override]
         kwargs: FormatterKwargs = {
             "episode": episode,
             "show": feed_info,
diff --git a/pyproject.toml b/pyproject.toml
index 105aee1..5d8e8df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,7 @@ lint.extend-select = [
     "C90", # mccabe
     "FA",  # future-annotations
     "TCH", # type-checking
+    "RUF", # ruff-specific
 ]
 lint.ignore = [
     "SIM108", # if-else-block-instead-of-if-exp
@@ -109,7 +110,7 @@ source = [
 
 [tool.coverage.report]
 exclude_also = [
-    "if TYPE_CHECKING:"
+    "if TYPE_CHECKING:",
 ]
 fail_under = 60
 precision = 2
diff --git a/tests/conftest.py b/tests/conftest.py
index 78605f0..2762ace 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,7 @@
 
 import os
 import re
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import TYPE_CHECKING, Iterable
 
@@ -9,6 +10,8 @@
 import pytest
 from pydantic_core import Url
 
+from podcast_archiver.models import Episode, Link
+
 if TYPE_CHECKING:
     from responses import RequestsMock
 
@@ -50,6 +53,7 @@ def feed_lautsprecher_empty(responses: RequestsMock) -> Url:
 
 @pytest.fixture
 def feedobj_lautsprecher(responses: RequestsMock) -> Url:
+    responses.assert_all_requests_are_fired = False
     responses.add(responses.GET, MEDIA_URL, b"BLOB")
     return FEED_OBJ
 
@@ -64,3 +68,18 @@ def tmp_path_cd(request: pytest.FixtureRequest, tmp_path: str) -> Iterable[str]:
     os.chdir(tmp_path)
     yield tmp_path
     os.chdir(request.config.invocation_params.dir)
+
+
+@pytest.fixture
+def episode() -> Episode:
+    return Episode(
+        title="Some Episode",
+        subtitle="The unreleased version",
+        author="Janw",
+        published_parsed=datetime(2023, 3, 12, 12, 34, 56, tzinfo=timezone.utc),
+        enclosure=Link(
+            rel="enclosure",
+            link_type="audio/mpeg",
+            href="http://nowhere.invalid/file.mp3",
+        ),
+    )
diff --git a/tests/test_database.py b/tests/test_database.py
new file mode 100644
index 0000000..0a6fc49
--- /dev/null
+++ b/tests/test_database.py
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from podcast_archiver.database import Database
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from podcast_archiver.models import Episode
+
+
+def test_add(tmp_path_cd: Path, episode: Episode) -> None:
+    db = Database("db.db", ignore_existing=False)
+
+    assert (tmp_path_cd / "db.db").is_file()
+    assert not db.exists(episode)
+    db.add(episode)
+    assert db.exists(episode)
+    db.add(episode)
+    assert db.exists(episode)
+
+
+def test_add_ignore_existing(tmp_path_cd: Path, episode: Episode) -> None:
+    db = Database("db.db", ignore_existing=True)
+
+    assert not db.exists(episode)
+    db.add(episode)
+    assert not db.exists(episode)
+
+
+def test_migrate_idempotency(tmp_path_cd: Path) -> None:
+    db = Database("db.db", ignore_existing=False)
+
+    db.migrate()
+    db.migrate()
+
+    assert (tmp_path_cd / "db.db").is_file()
diff --git a/tests/test_download.py b/tests/test_download.py
index 09f7916..d10c053 100644
--- a/tests/test_download.py
+++ b/tests/test_download.py
@@ -21,41 +21,36 @@ def test_download_job(tmp_path_cd: Path, feedobj_lautsprecher: dict[str, Any]) -> None:
     feed = FeedPage.model_validate(feedobj_lautsprecher)
     episode = feed.episodes[0]
 
-    mock_progress = mock.Mock(
-        **{
-            "add_task.return_value": 0,
-            "update.return_value": None,
-        }
-    )
-    job = download.DownloadJob(episode=episode, feed_info=feed.feed, target=Path("file.mp3"), progress=mock_progress)
+    mock_progress = mock.Mock(return_value=None)
+    job = download.DownloadJob(episode=episode, target=Path("file.mp3"), progress_callback=mock_progress)
     result = job()
 
-    assert result == DownloadResult.COMPLETED_SUCCESSFULLY
-    mock_progress.add_task.assert_called_once()
-    mock_progress.update.assert_called()
+    assert result == (episode, DownloadResult.COMPLETED_SUCCESSFULLY)
+    mock_progress.assert_called()
+    assert mock_progress.call_count == 2
 
 
 def test_download_already_exists(tmp_path_cd: Path, feedobj_lautsprecher_notconsumed: dict[str, Any]) -> None:
     feed = FeedPage.model_validate(feedobj_lautsprecher_notconsumed)
     episode = feed.episodes[0]
 
-    job = download.DownloadJob(episode=episode, feed_info=feed.feed, target=Path("file.mp3"))
+    job = download.DownloadJob(episode=episode, target=Path("file.mp3"))
     job.target.parent.mkdir(exist_ok=True)
     job.target.touch()
     result = job()
 
-    assert result == DownloadResult.ALREADY_EXISTS
+    assert result == (episode, DownloadResult.ALREADY_EXISTS)
 
 
 def test_download_aborted(tmp_path_cd: Path, feedobj_lautsprecher: dict[str, Any]) -> None:
     feed = FeedPage.model_validate(feedobj_lautsprecher)
     episode = feed.episodes[0]
 
-    job = download.DownloadJob(episode=episode, feed_info=feed.feed, target=Path("file.mp3"))
+    job = download.DownloadJob(episode=episode, target=Path("file.mp3"))
     job.stop_event.set()
     result = job()
 
-    assert result == DownloadResult.ABORTED
+    assert result == (episode, DownloadResult.ABORTED)
 
 
 class PartialObjectMock(Protocol):
@@ -84,11 +79,11 @@ def test_download_failed(
     if should_download:
         responses.add(responses.GET, MEDIA_URL, b"BLOB")
 
-    job = download.DownloadJob(episode=episode, feed_info=feed.feed, target=Path("file.mp3"))
+    job = download.DownloadJob(episode=episode, target=Path("file.mp3"))
     with failure_mode(side_effect=side_effect), caplog.at_level(logging.ERROR):
         result = job()
 
-    assert result == DownloadResult.FAILED
+    assert result == (episode, DownloadResult.FAILED)
     failure_rec = None
     for record in caplog.records:
         if record.message == "Download failed":
@@ -105,13 +100,8 @@
 def test_download_info_json(tmp_path_cd: Path, feedobj_lautsprecher: dict[str, Any], write_info_json: bool) -> None:
     feed = FeedPage.model_validate(feedobj_lautsprecher)
     episode = feed.episodes[0]
-    job = download.DownloadJob(
-        episode=episode,
-        feed_info=feed.feed,
-        target=tmp_path_cd / "file.mp3",
-        write_info_json=write_info_json,
-    )
+    job = download.DownloadJob(episode=episode, target=tmp_path_cd / "file.mp3", write_info_json=write_info_json)
     result = job()
 
-    assert result == DownloadResult.COMPLETED_SUCCESSFULLY
+    assert result == (episode, DownloadResult.COMPLETED_SUCCESSFULLY)
     assert job.infojsonfile.exists() == write_info_json
diff --git a/tests/test_filenames.py b/tests/test_filenames.py
index 56b4745..1122bdb 100644
--- a/tests/test_filenames.py
+++ b/tests/test_filenames.py
@@ -1,22 +1,9 @@
-from datetime import datetime, timezone
-
 import pytest
 
 from podcast_archiver.config import Settings
-from podcast_archiver.models import Episode, FeedInfo, Link
+from podcast_archiver.models import Episode, FeedInfo
 from podcast_archiver.utils import FilenameFormatter
 
-EPISODE = Episode(
-    title="Some Episode",
-    subtitle="The unreleased version",
-    author="Janw",
-    published_parsed=datetime(2023, 3, 12, 12, 34, 56, tzinfo=timezone.utc),
-    enclosure=Link(
-        rel="enclosure",
-        link_type="audio/mpeg",
-        href="http://nowhere.invalid/file.mp3",
-    ),
-)
 FEED_INFO = FeedInfo(
     title="That Show",
     subtitle="The one that never came to be",
@@ -50,10 +37,10 @@
         ),
     ],
 )
-def test_filename_formatting(fname_tmpl: str, slugify: bool, expected_fname: str) -> None:
+def test_filename_formatting(fname_tmpl: str, slugify: bool, expected_fname: str, episode: Episode) -> None:
     settings = Settings(filename_template=fname_tmpl, slugify_paths=slugify)
     formatter = FilenameFormatter(settings=settings)
-    result = formatter.format(EPISODE, feed_info=FEED_INFO)
+    result = formatter.format(episode, feed_info=FEED_INFO)
 
     assert str(result) == expected_fname
 
diff --git a/tests/test_processor.py b/tests/test_processor.py
new file mode 100644
index 0000000..dfac48b
--- /dev/null
+++ b/tests/test_processor.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+from unittest.mock import patch
+
+import pytest
+
+from podcast_archiver.config import Settings
+from podcast_archiver.enums import DownloadResult
+from podcast_archiver.models import FeedPage
+from podcast_archiver.processor import FeedProcessor
+
+if TYPE_CHECKING:
+    from pydantic_core import Url
+
+
+@pytest.mark.parametrize(
+    "file_exists,database_exists,expected_result",
+    [
+        (False, False, None),
+        (True, False, DownloadResult.ALREADY_EXISTS),
+        (False, True, DownloadResult.ALREADY_EXISTS),
+    ],
+)
+def test_preflight_check(
+    tmp_path_cd: Path,
+    feedobj_lautsprecher: Url,
+    file_exists: bool,
+    database_exists: bool,
+    expected_result: DownloadResult | None,
+) -> None:
+    settings = Settings()
+    feed = FeedPage.model_validate(feedobj_lautsprecher)
+    episode = feed.episodes[0]
+    target = Path("file.mp3")
+    proc = FeedProcessor(settings)
+    if file_exists:
+        target.touch()
+    with patch.object(proc.database, "exists", return_value=database_exists):
+        result = proc._preflight_check(episode, target=target)
+
+    assert result == expected_result
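Putting the pieces together, the settings-to-database plumbing can be exercised end to end like so (a sketch assuming a writable working directory; `archive.db` is an arbitrary example path, and `TESTING=1` in the environment would yield a `DummyDatabase` instead):

```python
from podcast_archiver.config import Settings

settings = Settings(database="archive.db", ignore_database=False)
db = settings.get_database()  # creates/migrates archive.db on first use
print(type(db).__name__)  # Database
```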