Skip to content

Commit

Permalink
feat: Add episodes database (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
janw committed Apr 17, 2024
1 parent 6741508 commit ce99e0c
Show file tree
Hide file tree
Showing 17 changed files with 569 additions and 161 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*.mp3
*.m4a
*.yaml
*.db
!.github/**/*.yaml

archive/
Expand Down
32 changes: 26 additions & 6 deletions podcast_archiver/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from podcast_archiver import __version__ as version
from podcast_archiver import constants
from podcast_archiver.base import PodcastArchiver
from podcast_archiver.config import Settings
from podcast_archiver.config import Settings, in_ci
from podcast_archiver.console import console
from podcast_archiver.exceptions import InvalidSettings
from podcast_archiver.logging import configure_logging
Expand All @@ -32,6 +32,7 @@
"--opml",
"--dir",
"--config",
"--database",
],
},
{
Expand All @@ -47,6 +48,7 @@
"options": [
"--update",
"--max-episodes",
"--ignore-database",
],
},
]
Expand Down Expand Up @@ -155,7 +157,6 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b
resolve_path=True,
path_type=pathlib.Path,
),
show_default=True,
required=False,
default=pathlib.Path("."),
show_envvar=True,
Expand Down Expand Up @@ -232,6 +233,7 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b
"maximum_episode_count",
type=int,
default=0,
show_envvar=True,
help=Settings.model_fields["maximum_episode_count"].description,
)
@click.version_option(
Expand All @@ -252,15 +254,33 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b
@click.option(
"-c",
"--config",
"config_path",
"config",
type=ConfigFile(),
default=get_default_config_path,
show_default=False,
default=get_default_config_path(),
show_default=not in_ci(),
is_eager=True,
envvar=constants.ENVVAR_PREFIX + "_CONFIG",
show_envvar=True,
help="Path to a config file. Command line arguments will take precedence.",
)
@click.option(
"--database",
type=click.Path(
exists=False,
dir_okay=False,
resolve_path=True,
),
default=None,
show_envvar=True,
help=Settings.model_fields["database"].description,
)
@click.option(
"--ignore-database",
type=bool,
is_flag=True,
show_envvar=True,
help=Settings.model_fields["ignore_database"].description,
)
@click.pass_context
def main(ctx: click.RichContext, /, **kwargs: Any) -> int:
configure_logging(kwargs["verbose"])
Expand All @@ -278,7 +298,7 @@ def main(ctx: click.RichContext, /, **kwargs: Any) -> int:
pa.run()
except InvalidSettings as exc:
raise click.BadParameter(f"Invalid settings: {exc}") from exc
except KeyboardInterrupt as exc:
except KeyboardInterrupt as exc: # pragma: no cover
raise click.Abort("Interrupted by user") from exc
except FileNotFoundError as exc:
raise click.Abort(exc) from exc
Expand Down
59 changes: 49 additions & 10 deletions podcast_archiver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,29 @@
import pathlib
import textwrap
from datetime import datetime
from functools import cached_property
from os import getenv
from typing import IO, Any, Text

import pydantic
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, DirectoryPath, Field, FilePath
from pydantic import (
AnyHttpUrl,
BaseModel,
BeforeValidator,
DirectoryPath,
Field,
FilePath,
NewPath,
)
from pydantic import ConfigDict as _ConfigDict
from pydantic_core import to_json
from typing_extensions import Annotated
from yaml import YAMLError, safe_load

from podcast_archiver import __version__ as version
from podcast_archiver import constants
from podcast_archiver.console import console
from podcast_archiver.database import BaseDatabase, Database, DummyDatabase
from podcast_archiver.exceptions import InvalidSettings
from podcast_archiver.models import ALL_FIELD_TITLES_STR
from podcast_archiver.utils import FilenameFormatter


def expanduser(v: pathlib.Path) -> pathlib.Path:
Expand All @@ -29,6 +36,12 @@ def expanduser(v: pathlib.Path) -> pathlib.Path:

UserExpandedDir = Annotated[DirectoryPath, BeforeValidator(expanduser)]
UserExpandedFile = Annotated[FilePath, BeforeValidator(expanduser)]
UserExpandedPossibleFile = Annotated[FilePath | NewPath, BeforeValidator(expanduser)]


def in_ci() -> bool:
    """Report whether the process appears to run in a CI environment.

    Returns:
        ``True`` when the ``CI`` environment variable equals ``"true"`` or
        ``"1"`` (compared case-insensitively), ``False`` otherwise, including
        when the variable is unset.
    """
    # Lower-case once; the original lowered the same value twice.
    return getenv("CI", "").lower() in ("true", "1")


class Settings(BaseModel):
Expand Down Expand Up @@ -108,7 +121,22 @@ class Settings(BaseModel):
description=f"Download only the first {constants.DEBUG_PARTIAL_SIZE} bytes of episodes for debugging purposes.",
)

config_path: FilePath | None = Field(
database: UserExpandedPossibleFile | None = Field(
default=None,
description=(
"Location of the database to keep track of downloaded episodes. By default, the database will be created "
f"as '{constants.DEFAULT_DATABASE_FILENAME}' in the directory of the config file."
),
)
ignore_database: bool = Field(
default=False,
description=(
"Ignore the episodes database when downloading. This will cause files to be downloaded again, even if they "
"already exist in the database."
),
)

config: FilePath | None = Field(
default=None,
exclude=True,
)
Expand All @@ -133,7 +161,7 @@ def load_from_yaml(cls, path: pathlib.Path) -> Settings:
if not isinstance(content, dict):
raise InvalidSettings("Not a valid YAML document")

content.update(config_path=path)
content.update(config=path)
return cls.load_from_dict(content)

@classmethod
Expand All @@ -147,7 +175,7 @@ def generate_default_config(cls, file: IO[Text] | None = None) -> None:
]

for name, field in cls.model_fields.items():
if name in ("config_path",):
if name in ("config",):
continue
cli_opt = (
wrapper.wrap(f"Equivalent command line option: --{field.alias.replace('_', '-')}")
Expand All @@ -166,11 +194,22 @@ def generate_default_config(cls, file: IO[Text] | None = None) -> None:

contents = "\n".join(lines).strip()
if not file:
from podcast_archiver.console import console

console.print(contents, highlight=False)
return
with file:
file.write(contents + "\n")

@cached_property
def filename_formatter(self) -> FilenameFormatter:
    """Return a ``FilenameFormatter`` constructed from these settings.

    Wrapped in ``cached_property``, so the formatter is built on first access
    and the same object is returned on later accesses of this instance.
    """
    return FilenameFormatter(self)
def get_database(self) -> BaseDatabase:
    """Build the episode database backend for these settings.

    Returns:
        A ``DummyDatabase`` when the ``TESTING`` environment variable is
        ``"1"`` or ``"true"`` (case-insensitive). Otherwise a ``Database``
        whose file location is chosen in this order:

        1. the explicitly configured ``database`` path,
        2. ``constants.DEFAULT_DATABASE_FILENAME`` next to the config file,
        3. the bare ``constants.DEFAULT_DATABASE_FILENAME`` (a relative path).
    """
    testing_flag = getenv("TESTING", "0").lower()
    if testing_flag in ("1", "true"):
        return DummyDatabase()

    if self.database:
        location = str(self.database)
    elif self.config:
        # No explicit database: keep it alongside the config file.
        location = str(self.config.parent / constants.DEFAULT_DATABASE_FILENAME)
    else:
        location = constants.DEFAULT_DATABASE_FILENAME

    return Database(filename=location, ignore_existing=self.ignore_database)

Check warning on line 215 in podcast_archiver/config.py

View check run for this annotation

Codecov / codecov/patch

podcast_archiver/config.py#L215

Added line #L215 was not covered by tests
97 changes: 97 additions & 0 deletions podcast_archiver/console.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,100 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from rich import progress
from rich.console import Console

if TYPE_CHECKING:
from podcast_archiver.config import Settings
from podcast_archiver.models import Episode
from podcast_archiver.types import ProgressCallback

# Module-level rich console; ProgressDisplay below renders through it.
console = Console()


# Column layout passed to ``progress.Progress`` by ProgressDisplay.
# The date column formats the task field ``date`` with ``%Y-%m-%d``, so every
# task added to the progress instance must supply a ``date`` field.
PROGRESS_COLUMNS = (
progress.SpinnerColumn(finished_text="[bar.finished]✔[/]"),
progress.TextColumn("[blue]{task.fields[date]:%Y-%m-%d}"),
progress.TextColumn("[progress.description]{task.description}"),
progress.BarColumn(bar_width=25),
progress.TaskProgressColumn(),
progress.TimeRemainingColumn(),
progress.DownloadColumn(),
progress.TransferSpeedColumn(),
)


def noop_callback(total: int | None = None, completed: int | None = None) -> None:
    """Accept progress values and do nothing with them.

    Both arguments are ignored; the function always returns ``None``.
    """
    return None


class ProgressDisplay:
    """Track one progress-bar task per episode on a shared ``progress.Progress``.

    Usable as a context manager: entering starts the live display and exiting
    stops it and forgets all registered episodes. When ``disabled`` is true,
    the update methods return early and ``get_callback`` returns a no-op.
    """

    # True when progress output is suppressed (verbosity above 1, or quiet).
    disabled: bool

    _progress: progress.Progress
    # Maps each registered episode to the id of its progress task.
    _state: dict[Episode, progress.TaskID]

    def __init__(self, settings: Settings) -> None:
        """Create the underlying progress instance from ``settings``.

        Args:
            settings: Read for ``verbose`` and ``quiet`` to decide whether the
                display is disabled.
        """
        self.disabled = settings.verbose > 1 or settings.quiet
        self._progress = progress.Progress(
            *PROGRESS_COLUMNS,
            console=console,
            disable=self.disabled,
        )
        self._progress.live.vertical_overflow = "visible"
        self._state = {}

    def _get_task_id(self, episode: Episode) -> progress.TaskID:
        """Return the task id for ``episode``, registering it on first use.

        The previous ``dict.get(episode, self.register(episode))`` form
        evaluated ``register()`` on every call, adding a new task and
        replacing the stored id each time. Registration now happens only
        when the episode is not yet known.
        """
        try:
            return self._state[episode]
        except KeyError:
            return self.register(episode)

    def __enter__(self) -> ProgressDisplay:
        """Start the progress display unless disabled; return ``self``."""
        if not self.disabled:
            self._progress.start()
        return self

    def __exit__(self, *args: Any) -> None:
        """Stop the progress display unless disabled, then clear the episode map."""
        if not self.disabled:
            self._progress.stop()
        self._state = {}

    def shutdown(self) -> None:
        """Hide every unfinished task, then stop the progress display."""
        for task in self._progress.tasks or []:
            if not task.finished:
                task.visible = False
        self._progress.stop()

    def register(self, episode: Episode) -> progress.TaskID:
        """Add a new, initially hidden task for ``episode`` and remember its id.

        The task uses the episode title as description, the published time as
        the ``date`` field, and the enclosure length as total.

        Returns:
            The id of the newly added task.
        """
        task_id = self._progress.add_task(
            description=episode.title,
            date=episode.published_time,
            total=episode.enclosure.length,
            visible=False,
        )
        self._state[episode] = task_id
        return task_id

    def update(self, episode: Episode, visible: bool = True, **kwargs: Any) -> None:
        """Forward ``visible`` and ``kwargs`` to the episode's task; no-op when disabled."""
        if self.disabled:
            return

        task_id = self._get_task_id(episode)
        self._progress.update(task_id, visible=visible, **kwargs)

    def completed(self, episode: Episode, visible: bool = True, **kwargs: Any) -> None:
        """Set the episode's task to its full enclosure length; no-op when disabled."""
        if self.disabled:
            return

        task_id = self._get_task_id(episode)
        self._progress.update(task_id, visible=visible, completed=episode.enclosure.length, **kwargs)

    def get_callback(self, episode: Episode) -> ProgressCallback:
        """Return a callback that reports ``total``/``completed`` for ``episode``.

        When the display is disabled, ``noop_callback`` is returned instead.
        Otherwise the task id is resolved once here and captured by the
        returned closure, which also makes the task visible on each call.
        """
        if self.disabled:
            return noop_callback

        task_id = self._get_task_id(episode)

        def _callback(total: int | None = None, completed: int | None = None) -> None:
            self._progress.update(task_id, total=total, completed=completed, visible=True)

        return _callback
1 change: 1 addition & 0 deletions podcast_archiver/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
DEFAULT_ARCHIVE_DIRECTORY = pathlib.Path(".")
DEFAULT_FILENAME_TEMPLATE = "{show.title}/{episode.published_time:%Y-%m-%d} - {episode.title}.{ext}"
DEFAULT_CONCURRENCY = 4
DEFAULT_DATABASE_FILENAME = "podcast-archiver.db"
78 changes: 78 additions & 0 deletions podcast_archiver/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import sqlite3
from abc import ABC, abstractmethod
from contextlib import contextmanager
from threading import Lock
from typing import TYPE_CHECKING, Iterator

from podcast_archiver.logging import logger

if TYPE_CHECKING:
from podcast_archiver.models import Episode


class BaseDatabase(ABC):
    """Interface for stores that record which episodes were already handled.

    Inherits from ``ABC`` so that the ``@abstractmethod`` markers are actually
    enforced: the base class, and any subclass missing ``add`` or ``exists``,
    cannot be instantiated.
    """

    @abstractmethod
    def add(self, episode: Episode) -> None:
        """Record ``episode`` in the store."""
        pass  # pragma: no cover

    @abstractmethod
    def exists(self, episode: Episode) -> bool:
        """Return whether ``episode`` is already recorded in the store."""
        pass  # pragma: no cover


class DummyDatabase(BaseDatabase):
    """Database stand-in that stores nothing and reports every episode as unknown."""

    def add(self, episode: Episode) -> None:
        """Ignore ``episode``; nothing is recorded."""
        return None

    def exists(self, episode: Episode) -> bool:
        """Always return ``False``."""
        return False


class Database(BaseDatabase):
    """SQLite-backed record of episodes, keyed by episode GUID.

    A new connection is opened for every operation and closed afterwards.
    Access is serialized through a lock defined on the class, so it is shared
    by all instances.
    """

    # Path of the SQLite database file.
    filename: str
    # When true, ``exists`` always reports False without querying.
    ignore_existing: bool

    lock = Lock()

    def __init__(self, filename: str, ignore_existing: bool) -> None:
        """Store the settings and make sure the schema exists.

        Args:
            filename: Path handed to ``sqlite3.connect``.
            ignore_existing: If true, ``exists`` never reports a hit.
        """
        self.filename = filename
        self.ignore_existing = ignore_existing
        self.migrate()

    @contextmanager
    def get_conn(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection while holding the class lock.

        ``with conn`` only commits (or rolls back on error); it does not close
        the connection, so the connection is closed explicitly in ``finally``
        to avoid leaking one handle per call.
        """
        with self.lock:
            conn = sqlite3.connect(self.filename)
            try:
                with conn:
                    yield conn
            finally:
                conn.close()

    def migrate(self) -> None:
        """Create the ``episodes`` table if it does not exist yet."""
        logger.debug(f"Migrating database at {self.filename}")
        with self.get_conn() as conn:
            conn.execute(
                """\
                CREATE TABLE IF NOT EXISTS episodes(
                    guid TEXT UNIQUE NOT NULL,
                    title TEXT
                )"""
            )

    def add(self, episode: Episode) -> None:
        """Insert the episode's GUID and title.

        An ``sqlite3.IntegrityError`` (for example, a GUID that is already
        present) is logged at debug level and otherwise ignored.
        """
        with self.get_conn() as conn:
            try:
                conn.execute(
                    "INSERT INTO episodes(guid, title) VALUES (?, ?)",
                    (episode.guid, episode.title),
                )
            except sqlite3.IntegrityError:
                logger.debug(f"Episode exists: {episode}")

    def exists(self, episode: Episode) -> bool:
        """Return whether a row with the episode's GUID is present.

        Always returns ``False`` when ``ignore_existing`` is set.
        """
        if self.ignore_existing:
            return False
        with self.get_conn() as conn:
            result = conn.execute(
                "SELECT EXISTS(SELECT 1 FROM episodes WHERE guid = ?)",
                (episode.guid,),
            )
            return bool(result.fetchone()[0])
Loading

0 comments on commit ce99e0c

Please sign in to comment.