From 12f75f094f07fa2b15b9d81562d887dbf94c84f7 Mon Sep 17 00:00:00 2001 From: Kunal Mehta Date: Tue, 29 Oct 2024 18:00:58 -0400 Subject: [PATCH] Handle get_db_object() failures as DownloadException Our error handling expects that failures from DownloadJobs raise DownloadExceptions, so let's do that. This is roughly the same concept as a9f0590f746a, just in a different part of the code. The exception handling has to exist in each class (instead of `one_or_none()`) because we need to know the corresponding type to pass to DownloadException. Fixes #2274. --- .../securedrop_client/api_jobs/downloads.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/client/securedrop_client/api_jobs/downloads.py b/client/securedrop_client/api_jobs/downloads.py index 70802f102..7499bd69b 100644 --- a/client/securedrop_client/api_jobs/downloads.py +++ b/client/securedrop_client/api_jobs/downloads.py @@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile from typing import Any +from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.session import Session from securedrop_client.api_jobs.base import SingleObjectApiJob @@ -108,9 +109,10 @@ def call_decrypt(self, filepath: str, session: Session | None = None) -> str: """ raise NotImplementedError - def get_db_object(self, session: Session) -> File | Message: + def get_db_object(self, session: Session) -> File | Message | Reply: """ - Get the database object associated with this job. + Get the database object associated with this job; may raise + DownloadException if not found """ raise NotImplementedError @@ -242,7 +244,10 @@ def get_db_object(self, session: Session) -> Reply: """ Override DownloadJob. """ - return session.query(Reply).filter_by(uuid=self.uuid).one() + try: + return session.query(Reply).filter_by(uuid=self.uuid).one() + except NoResultFound: + raise DownloadException("Reply not found in database", Reply, self.uuid) def call_download_api(self, api: API, db_object: Reply) -> tuple[str, str]: """ @@ -298,7 +303,10 @@ def get_db_object(self, session: Session) -> Message: """ Override DownloadJob. """ - return session.query(Message).filter_by(uuid=self.uuid).one() + try: + return session.query(Message).filter_by(uuid=self.uuid).one() + except NoResultFound: + raise DownloadException("Message not found in database", Message, self.uuid) def call_download_api(self, api: API, db_object: Message) -> tuple[str, str]: """ @@ -354,7 +362,10 @@ def get_db_object(self, session: Session) -> File: """ Override DownloadJob. """ - return session.query(File).filter_by(uuid=self.uuid).one() + try: + return session.query(File).filter_by(uuid=self.uuid).one() + except NoResultFound: + raise DownloadException("File not found in database", File, self.uuid) def call_download_api(self, api: API, db_object: File) -> tuple[str, str]: """