Skip to content

Commit

Permalink
feat: integrate dependency injection with kink library (#859)
Browse files Browse the repository at this point in the history
* feat: integrate dependency injection with kink library

- Added dependency injection using the kink library to manage API instances and service initialization.
- Updated various modules to utilize dependency injection for better modularity and testability.
- Refactored API initialization and validation logic to be more centralized and consistent.
- Enhanced Trakt, Plex, Overseerr, Mdblist, and Listrr services to use injected dependencies.
- Updated CLI and service modules to align with the new dependency injection approach.
- Modified pyproject.toml to include kink as a dependency.

* feat: enhance Trakt API with OAuth support and settings integration

- Updated TraktAPI to accept settings via TraktModel, enabling OAuth configuration.
- Added OAuth flow methods to handle authorization and token exchange.
- Integrated TraktOauthModel into TraktModel for structured OAuth settings.
- Modified API bootstrap to pass settings to TraktAPI.
- Ensured backward compatibility with existing settings structure.

* fix: assignment of trakt api key in oauth

* fix: duplicate import

* fix: correct TraktAPI settings initialization in constructor
  • Loading branch information
iPromKnight authored Nov 5, 2024
1 parent c80f609 commit ed5fb2c
Show file tree
Hide file tree
Showing 19 changed files with 176 additions and 95 deletions.
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ psutil = "^6.0.0"
python-dotenv = "^1.0.1"
requests-ratelimiter = "^0.7.0"
requests-cache = "^1.2.1"
kink = "^0.8.1"

[tool.poetry.group.dev.dependencies]
pyright = "^1.1.352"
Expand Down
41 changes: 41 additions & 0 deletions src/program/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,42 @@
from .listrr_api import ListrrAPI, ListrrAPIError
from .trakt_api import TraktAPI, TraktAPIError
from .plex_api import PlexAPI, PlexAPIError
from .overseerr_api import OverseerrAPI, OverseerrAPIError
from .mdblist_api import MdblistAPI, MdblistAPIError
from program.settings.manager import settings_manager
from kink import di

def bootstrap_apis():
__setup_trakt()
__setup_plex()
__setup_mdblist()
__setup_overseerr()
__setup_listrr()

def __setup_trakt():
traktApi = TraktAPI(settings_manager.settings.content.trakt)
di[TraktAPI] = traktApi

def __setup_plex():
if not settings_manager.settings.updaters.plex.enabled:
return
plexApi = PlexAPI(settings_manager.settings.updaters.plex.token, settings_manager.settings.updaters.plex.url)
di[PlexAPI] = plexApi

def __setup_overseerr():
if not settings_manager.settings.content.overseerr.enabled:
return
overseerrApi = OverseerrAPI(settings_manager.settings.content.overseerr.api_key, settings_manager.settings.content.overseerr.url)
di[OverseerrAPI] = overseerrApi

def __setup_mdblist():
if not settings_manager.settings.content.mdblist.enabled:
return
mdblistApi = MdblistAPI(settings_manager.settings.content.mdblist.api_key)
di[MdblistAPI] = mdblistApi

def __setup_listrr():
if not settings_manager.settings.content.listrr.enabled:
return
listrrApi = ListrrAPI(settings_manager.settings.content.listrr.api_key)
di[ListrrAPI] = listrrApi
4 changes: 2 additions & 2 deletions src/program/apis/listrr_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from loguru import logger
from requests.exceptions import HTTPError

from kink import di
from program.apis.trakt_api import TraktAPI
from program.media.item import MediaItem
from program.utils.request import create_service_session, BaseRequestHandler, Session, ResponseType, ResponseObject, HttpMethod
Expand All @@ -25,7 +25,7 @@ def __init__(self, api_key: str):
session = create_service_session()
session.headers.update(self.headers)
self.request_handler = ListrrRequestHandler(session, base_url=self.BASE_URL)
self.trakt_api = TraktAPI(rate_limit=False)
self.trakt_api = di[TraktAPI]

def validate(self):
return self.request_handler.execute(HttpMethod.GET, self.BASE_URL)
Expand Down
4 changes: 2 additions & 2 deletions src/program/apis/overseerr_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from loguru import logger
from requests.exceptions import ConnectionError, RetryError
from urllib3.exceptions import MaxRetryError

from kink import di
from program.apis.trakt_api import TraktAPI
from program.media.item import MediaItem
from program.settings.manager import settings_manager
Expand All @@ -27,7 +27,7 @@ def __init__(self, api_key: str, base_url: str):
self.api_key = api_key
rate_limit_params = get_rate_limit_params(max_calls=1000, period=300)
session = create_service_session(rate_limit_params=rate_limit_params)
self.trakt_api = TraktAPI(rate_limit=False)
self.trakt_api = di[TraktAPI]
self.headers = {"X-Api-Key": self.api_key}
session.headers.update(self.headers)
self.request_handler = OverseerrRequestHandler(session, base_url=base_url)
Expand Down
11 changes: 9 additions & 2 deletions src/program/apis/plex_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def execute(self, method: HttpMethod, endpoint: str, overriden_response_type: Re
class PlexAPI:
"""Handles Plex API communication"""

def __init__(self, token: str, base_url: str, rss_urls: Optional[List[str]]):
self.rss_urls = rss_urls
def __init__(self, token: str, base_url: str):
self.rss_urls: Optional[List[str]] = None
self.token = token
self.BASE_URL = base_url
session = create_service_session()
Expand All @@ -43,6 +43,13 @@ def validate_account(self):
def validate_server(self):
self.plex_server = PlexServer(self.BASE_URL, token=self.token, session=self.request_handler.session, timeout=60)

def set_rss_urls(self, rss_urls: List[str]):
self.rss_urls = rss_urls

def clear_rss_urls(self):
self.rss_urls = None
self.rss_enabled = False

def validate_rss(self, url: str):
return self.request_handler.execute(HttpMethod.GET, url)

Expand Down
80 changes: 62 additions & 18 deletions src/program/apis/trakt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@
from datetime import datetime
from types import SimpleNamespace
from typing import Union, List, Optional
from urllib.parse import urlencode
from requests import RequestException, Session
from program import MediaItem
from program.media import Movie, Show, Season, Episode
from program.settings.manager import settings_manager
from program.settings.models import TraktModel
from program.utils.request import get_rate_limit_params, create_service_session, logger, BaseRequestHandler, \
ResponseType, HttpMethod, ResponseObject

ResponseType, HttpMethod, ResponseObject, get_cache_params

class TraktAPIError(Exception):
"""Base exception for TraktApi related errors"""

class TraktRequestHandler(BaseRequestHandler):
def __init__(self, session: Session, request_logging: bool = False):
super().__init__(session, response_type=ResponseType.SIMPLE_NAMESPACE, custom_exception=TraktAPIError, request_logging=request_logging)
def __init__(self, session: Session, response_type=ResponseType.SIMPLE_NAMESPACE, request_logging: bool = False):
super().__init__(session, response_type=response_type, custom_exception=TraktAPIError, request_logging=request_logging)

def execute(self, method: HttpMethod, endpoint: str, **kwargs) -> ResponseObject:
return super()._request(method, endpoint, **kwargs)


class TraktAPI:
"""Handles Trakt API communication"""
BASE_URL = "https://api.trakt.tv"
Expand All @@ -29,16 +32,17 @@ class TraktAPI:
"short_list": re.compile(r"https://trakt.tv/lists/\d+")
}

def __init__(self, api_key: Optional[str] = None, rate_limit: bool = True):
self.api_key = api_key
rate_limit_params = get_rate_limit_params(max_calls=1000, period=300) if rate_limit else None
session = create_service_session(
rate_limit_params=rate_limit_params,
use_cache=False
)
def __init__(self, settings: TraktModel):
self.settings = settings
self.oauth_client_id = self.settings.oauth.oauth_client_id
self.oauth_client_secret = self.settings.oauth.oauth_client_secret
self.oauth_redirect_uri = self.settings.oauth.oauth_redirect_uri
rate_limit_params = get_rate_limit_params(max_calls=1000, period=300)
trakt_cache = get_cache_params("trakt", 86400)
session = create_service_session(rate_limit_params=rate_limit_params, use_cache=True, cache_params=trakt_cache)
self.headers = {
"Content-type": "application/json",
"trakt-api-key": self.api_key or self.CLIENT_ID,
"trakt-api-key": self.CLIENT_ID,
"trakt-api-version": "2"
}
session.headers.update(self.headers)
Expand Down Expand Up @@ -148,15 +152,15 @@ def get_show(self, imdb_id: str) -> dict:
"""Wrapper for trakt.tv API show method."""
if not imdb_id:
return {}
url = f"https://api.trakt.tv/shows/{imdb_id}/seasons?extended=episodes,full"
url = f"{self.BASE_URL}/shows/{imdb_id}/seasons?extended=episodes,full"
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
return response.data if response.is_ok and response.data else {}

def get_show_aliases(self, imdb_id: str, item_type: str) -> List[dict]:
"""Wrapper for trakt.tv API show method."""
if not imdb_id:
return []
url = f"https://api.trakt.tv/{item_type}/{imdb_id}/aliases"
url = f"{self.BASE_URL}/{item_type}/{imdb_id}/aliases"
try:
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if response.is_ok and response.data:
Expand All @@ -178,7 +182,7 @@ def get_show_aliases(self, imdb_id: str, item_type: str) -> List[dict]:

def create_item_from_imdb_id(self, imdb_id: str, type: str = None) -> Optional[MediaItem]:
"""Wrapper for trakt.tv API search method."""
url = f"https://api.trakt.tv/search/imdb/{imdb_id}?extended=full"
url = f"{self.BASE_URL}/search/imdb/{imdb_id}?extended=full"
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if not response.is_ok or not response.data:
logger.error(
Expand All @@ -194,7 +198,7 @@ def create_item_from_imdb_id(self, imdb_id: str, type: str = None) -> Optional[M

def get_imdbid_from_tmdb(self, tmdb_id: str, type: str = "movie") -> Optional[str]:
"""Wrapper for trakt.tv API search method."""
url = f"https://api.trakt.tv/search/tmdb/{tmdb_id}" # ?extended=full
url = f"{self.BASE_URL}/search/tmdb/{tmdb_id}" # ?extended=full
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if not response.is_ok or not response.data:
return None
Expand All @@ -206,7 +210,7 @@ def get_imdbid_from_tmdb(self, tmdb_id: str, type: str = "movie") -> Optional[st

def get_imdbid_from_tvdb(self, tvdb_id: str, type: str = "show") -> Optional[str]:
"""Wrapper for trakt.tv API search method."""
url = f"https://api.trakt.tv/search/tvdb/{tvdb_id}"
url = f"{self.BASE_URL}/search/tvdb/{tvdb_id}"
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if not response.is_ok or not response.data:
return None
Expand All @@ -219,7 +223,7 @@ def get_imdbid_from_tvdb(self, tvdb_id: str, type: str = "show") -> Optional[str
def resolve_short_url(self, short_url) -> Union[str, None]:
"""Resolve short URL to full URL"""
try:
response = self.request_handler.execute(HttpMethod.GET, url=short_url, additional_headers={"Content-Type": "application/json", "Accept": "text/html"})
response = self.request_handler.execute(HttpMethod.GET, endpoint=short_url, additional_headers={"Content-Type": "application/json", "Accept": "text/html"})
if response.is_ok:
return response.response.url
else:
Expand Down Expand Up @@ -279,6 +283,46 @@ def map_item_from_data(self, data, item_type: str, show_genres: List[str] = None
logger.error(f"Unknown item type {item_type} for {data.title} not found in list of acceptable items")
return None

def perform_oauth_flow(self) -> str:
"""Initiate the OAuth flow and return the authorization URL."""
if not self.oauth_client_id or not self.oauth_client_secret or not self.oauth_redirect_uri:
logger.error("OAuth settings not found in Trakt settings")
raise TraktAPIError("OAuth settings not found in Trakt settings")

params = {
"response_type": "code",
"client_id": self.oauth_client_id,
"redirect_uri": self.oauth_redirect_uri,
}
return f"{self.BASE_URL}/oauth/authorize?{urlencode(params)}"

def handle_oauth_callback(self, api_key:str, code: str) -> bool:
"""Handle the OAuth callback and exchange the code for an access token."""
if not self.oauth_client_id or not self.oauth_client_secret or not self.oauth_redirect_uri:
logger.error("OAuth settings not found in Trakt settings")
return False

token_url = f"{self.BASE_URL}/oauth/token"
payload = {
"code": code,
"client_id": self.oauth_client_id,
"client_secret": self.oauth_client_secret,
"redirect_uri": self.oauth_redirect_uri,
"grant_type": "authorization_code",
}
headers = self.headers.copy()
headers["trakt-api-key"] = api_key
response = self.request_handler.execute(HttpMethod.POST, token_url, data=payload, additional_headers=headers)
if response.is_ok:
token_data = response.data
self.settings.access_token = token_data.get("access_token")
self.settings.refresh_token = token_data.get("refresh_token")
settings_manager.save() # Save the tokens to settings
return True
else:
logger.error(f"Failed to obtain OAuth token: {response.status_code}")
return False

def _get_imdb_id_from_list(self, namespaces: List[SimpleNamespace], id_type: str = None, _id: str = None,
type: str = None) -> Optional[str]:
"""Get the imdb_id from the list of namespaces."""
Expand Down
2 changes: 1 addition & 1 deletion src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from sqlalchemy import delete, desc, func, insert, inspect, select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, joinedload, selectinload
from program.utils import root_dir

import alembic
from program.utils import root_dir
from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation
from program.services.libraries.symlink import fix_broken_symlinks
from program.settings.manager import settings_manager
Expand Down
8 changes: 7 additions & 1 deletion src/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
PlexWatchlist,
TraktContent,
)
from program.apis import bootstrap_apis
from program.services.downloaders import Downloader
from program.services.indexers.trakt import TraktIndexer
from program.services.libraries import SymlinkLibrary
Expand Down Expand Up @@ -64,8 +65,11 @@ def __init__(self):
self.malloc_time = time.monotonic()-50
self.last_snapshot = None

def initialize_services(self):
def initialize_apis(self):
bootstrap_apis()

def initialize_services(self):
"""Initialize all services."""
self.requesting_services = {
Overseerr: Overseerr(),
PlexWatchlist: PlexWatchlist(),
Expand Down Expand Up @@ -122,13 +126,15 @@ def start(self):
latest_version = get_version()
logger.log("PROGRAM", f"Riven v{latest_version} starting!")

settings_manager.register_observer(self.initialize_apis)
settings_manager.register_observer(self.initialize_services)
os.makedirs(data_dir_path, exist_ok=True)

if not settings_manager.settings_file.exists():
logger.log("PROGRAM", "Settings file not found, creating default settings")
settings_manager.save()

self.initialize_apis()
self.initialize_services()

max_worker_env_vars = [var for var in os.environ if var.endswith("_MAX_WORKERS")]
Expand Down
4 changes: 3 additions & 1 deletion src/program/services/content/listrr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Listrr content module"""
from typing import Generator
from kink import di
from program.utils.request import logger
from program.media.item import MediaItem
from program.settings.manager import settings_manager
Expand All @@ -12,7 +13,7 @@ class Listrr:
def __init__(self):
self.key = "listrr"
self.settings = settings_manager.settings.content.listrr
self.api = ListrrAPI(self.settings.api_key)
self.api = None
self.initialized = self.validate()
if not self.initialized:
return
Expand Down Expand Up @@ -40,6 +41,7 @@ def validate(self) -> bool:
logger.error("Both Movie and Show lists are empty or not set.")
return False
try:
self.api = di[ListrrAPI]
response = self.api.validate()
if not response.is_ok:
logger.error(
Expand Down
5 changes: 3 additions & 2 deletions src/program/services/content/mdblist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Generator
from loguru import logger

from kink import di
from program.apis.mdblist_api import MdblistAPI
from program.media.item import MediaItem
from program.settings.manager import settings_manager
Expand All @@ -14,7 +14,7 @@ class Mdblist:
def __init__(self):
self.key = "mdblist"
self.settings = settings_manager.settings.content.mdblist
self.api = MdblistAPI(self.settings.api_key)
self.api = None
self.initialized = self.validate()
if not self.initialized:
return
Expand All @@ -30,6 +30,7 @@ def validate(self):
if not self.settings.lists:
logger.error("Mdblist is enabled, but list is empty.")
return False
self.api = di[MdblistAPI]
response = self.api.validate()
if "Invalid API key!" in response.response.text:
logger.error("Mdblist api key is invalid.")
Expand Down
Loading

0 comments on commit ed5fb2c

Please sign in to comment.