diff --git a/alembic/versions/7f682532afa2_add_download_error.py b/alembic/versions/7f682532afa2_add_download_error.py new file mode 100644 index 000000000..7ec01d5b4 --- /dev/null +++ b/alembic/versions/7f682532afa2_add_download_error.py @@ -0,0 +1,289 @@ +"""add download error + +Revision ID: 7f682532afa2 +Revises: fb657f2ee8a7 +Create Date: 2020-04-15 13:44:21.434312 + +""" +from alembic import op +import sqlalchemy as sa +from securedrop_client import db + +# revision identifiers, used by Alembic. +revision = "7f682532afa2" +down_revision = "fb657f2ee8a7" +branch_labels = None +depends_on = None + + +CREATE_TABLE_FILES_NEW = """ + CREATE TABLE files ( + id INTEGER NOT NULL, + uuid VARCHAR(36) NOT NULL, + filename VARCHAR(255) NOT NULL, + file_counter INTEGER NOT NULL, + size INTEGER NOT NULL, + download_url VARCHAR(255) NOT NULL, + is_downloaded BOOLEAN DEFAULT 0 NOT NULL, + is_decrypted BOOLEAN CONSTRAINT files_compare_is_downloaded_vs_is_decrypted CHECK (CASE WHEN is_downloaded = 0 THEN is_decrypted IS NULL ELSE 1 END), + download_error_id INTEGER, + is_read BOOLEAN DEFAULT 0 NOT NULL, + source_id INTEGER NOT NULL, + last_updated DATETIME NOT NULL, + CONSTRAINT pk_files PRIMARY KEY (id), + CONSTRAINT uq_messages_source_id_file_counter UNIQUE (source_id, file_counter), + CONSTRAINT uq_files_uuid UNIQUE (uuid), + CONSTRAINT ck_files_is_downloaded CHECK (is_downloaded IN (0, 1)), + CONSTRAINT ck_files_is_decrypted CHECK (is_decrypted IN (0, 1)), + CONSTRAINT fk_files_download_error_id_downloaderrors FOREIGN KEY(download_error_id) REFERENCES downloaderrors (id), + CONSTRAINT ck_files_is_read CHECK (is_read IN (0, 1)), + CONSTRAINT fk_files_source_id_sources FOREIGN KEY(source_id) REFERENCES sources (id) +); +""" + +CREATE_TABLE_FILES_OLD = """ + CREATE TABLE files ( + id INTEGER NOT NULL, + uuid VARCHAR(36) NOT NULL, + filename VARCHAR(255) NOT NULL, + file_counter INTEGER NOT NULL, + size INTEGER NOT NULL, + download_url VARCHAR(255) NOT NULL, + is_downloaded BOOLEAN DEFAULT 0 NOT NULL, + is_read BOOLEAN DEFAULT 0 NOT NULL, + is_decrypted BOOLEAN, + source_id INTEGER NOT NULL, + CONSTRAINT pk_files PRIMARY KEY (id), + CONSTRAINT fk_files_source_id_sources FOREIGN KEY(source_id) REFERENCES sources (id), + CONSTRAINT uq_messages_source_id_file_counter UNIQUE (source_id, file_counter), + CONSTRAINT uq_files_uuid UNIQUE (uuid), + CONSTRAINT files_compare_is_downloaded_vs_is_decrypted + CHECK (CASE WHEN is_downloaded = 0 THEN is_decrypted IS NULL ELSE 1 END), + CONSTRAINT ck_files_is_downloaded CHECK (is_downloaded IN (0, 1)), + CONSTRAINT ck_files_is_read CHECK (is_read IN (0, 1)), + CONSTRAINT ck_files_is_decrypted CHECK (is_decrypted IN (0, 1)) +); +""" + + +CREATE_TABLE_MESSAGES_NEW = """ + CREATE TABLE messages ( + id INTEGER NOT NULL, + uuid VARCHAR(36) NOT NULL, + filename VARCHAR(255) NOT NULL, + file_counter INTEGER NOT NULL, + size INTEGER NOT NULL, + download_url VARCHAR(255) NOT NULL, + is_downloaded BOOLEAN DEFAULT 0 NOT NULL, + is_decrypted BOOLEAN CONSTRAINT messages_compare_is_downloaded_vs_is_decrypted CHECK (CASE WHEN is_downloaded = 0 THEN is_decrypted IS NULL ELSE 1 END), + download_error_id INTEGER, + is_read BOOLEAN DEFAULT 0 NOT NULL, + content TEXT CONSTRAINT ck_message_compare_download_vs_content CHECK (CASE WHEN is_downloaded = 0 THEN content IS NULL ELSE 1 END), + source_id INTEGER NOT NULL, + last_updated DATETIME NOT NULL, + CONSTRAINT pk_messages PRIMARY KEY (id), + CONSTRAINT uq_messages_source_id_file_counter UNIQUE (source_id, file_counter), + CONSTRAINT uq_messages_uuid UNIQUE (uuid), + CONSTRAINT ck_messages_is_downloaded CHECK (is_downloaded IN (0, 1)), + CONSTRAINT ck_messages_is_decrypted CHECK (is_decrypted IN (0, 1)), + CONSTRAINT fk_messages_download_error_id_downloaderrors FOREIGN KEY(download_error_id) REFERENCES downloaderrors (id), + CONSTRAINT ck_messages_is_read CHECK (is_read IN (0, 1)), + CONSTRAINT fk_messages_source_id_sources FOREIGN KEY(source_id) REFERENCES sources (id) + ); +""" + +CREATE_TABLE_MESSAGES_OLD = """ + CREATE TABLE messages ( + id INTEGER NOT NULL, + uuid VARCHAR(36) NOT NULL, + source_id INTEGER NOT NULL, + filename VARCHAR(255) NOT NULL, + file_counter INTEGER NOT NULL, + size INTEGER NOT NULL, + content TEXT, + is_decrypted BOOLEAN, + is_downloaded BOOLEAN DEFAULT 0 NOT NULL, + is_read BOOLEAN DEFAULT 0 NOT NULL, + download_url VARCHAR(255) NOT NULL, + CONSTRAINT pk_messages PRIMARY KEY (id), + CONSTRAINT uq_messages_source_id_file_counter UNIQUE (source_id, file_counter), + CONSTRAINT uq_messages_uuid UNIQUE (uuid), + CONSTRAINT fk_messages_source_id_sources FOREIGN KEY(source_id) REFERENCES sources (id), + CONSTRAINT ck_message_compare_download_vs_content + CHECK (CASE WHEN is_downloaded = 0 THEN content IS NULL ELSE 1 END), + CONSTRAINT messages_compare_is_downloaded_vs_is_decrypted + CHECK (CASE WHEN is_downloaded = 0 THEN is_decrypted IS NULL ELSE 1 END), + CONSTRAINT ck_messages_is_decrypted CHECK (is_decrypted IN (0, 1)), + CONSTRAINT ck_messages_is_downloaded CHECK (is_downloaded IN (0, 1)), + CONSTRAINT ck_messages_is_read CHECK (is_read IN (0, 1)) + ); +""" + + +CREATE_TABLE_REPLIES_NEW = """ + CREATE TABLE replies ( + id INTEGER NOT NULL, + uuid VARCHAR(36) NOT NULL, + source_id INTEGER NOT NULL, + filename VARCHAR(255) NOT NULL, + file_counter INTEGER NOT NULL, + size INTEGER, + content TEXT, + is_decrypted BOOLEAN, + is_downloaded BOOLEAN, + download_error_id INTEGER, + journalist_id INTEGER, + last_updated DATETIME NOT NULL, + CONSTRAINT pk_replies PRIMARY KEY (id), + CONSTRAINT uq_messages_source_id_file_counter UNIQUE (source_id, file_counter), + CONSTRAINT uq_replies_uuid UNIQUE (uuid), + CONSTRAINT fk_replies_source_id_sources FOREIGN KEY(source_id) REFERENCES sources (id), + CONSTRAINT fk_replies_download_error_id_downloaderrors + FOREIGN KEY(download_error_id) REFERENCES downloaderrors (id), + CONSTRAINT fk_replies_journalist_id_users FOREIGN KEY(journalist_id) REFERENCES users (id), + CONSTRAINT replies_compare_download_vs_content + CHECK (CASE WHEN is_downloaded = 0 THEN content IS NULL ELSE 1 END), + CONSTRAINT replies_compare_is_downloaded_vs_is_decrypted + CHECK (CASE WHEN is_downloaded = 0 THEN is_decrypted IS NULL ELSE 1 END), + CONSTRAINT ck_replies_is_decrypted CHECK (is_decrypted IN (0, 1)), + CONSTRAINT ck_replies_is_downloaded CHECK (is_downloaded IN (0, 1)) + ); +""" + +CREATE_TABLE_REPLIES_OLD = """ + CREATE TABLE replies ( + id INTEGER NOT NULL, + uuid VARCHAR(36) NOT NULL, + source_id INTEGER NOT NULL, + filename VARCHAR(255) NOT NULL, + file_counter INTEGER NOT NULL, + size INTEGER, + content TEXT, + is_decrypted BOOLEAN, + is_downloaded BOOLEAN, + journalist_id INTEGER, + CONSTRAINT pk_replies PRIMARY KEY (id), + CONSTRAINT uq_messages_source_id_file_counter UNIQUE (source_id, file_counter), + CONSTRAINT uq_replies_uuid UNIQUE (uuid), + CONSTRAINT fk_replies_source_id_sources FOREIGN KEY(source_id) REFERENCES sources (id), + CONSTRAINT fk_replies_journalist_id_users FOREIGN KEY(journalist_id) REFERENCES users (id), + CONSTRAINT replies_compare_download_vs_content + CHECK (CASE WHEN is_downloaded = 0 THEN content IS NULL ELSE 1 END), + CONSTRAINT replies_compare_is_downloaded_vs_is_decrypted + CHECK (CASE WHEN is_downloaded = 0 THEN is_decrypted IS NULL ELSE 1 END), + CONSTRAINT ck_replies_is_decrypted CHECK (is_decrypted IN (0, 1)), + CONSTRAINT ck_replies_is_downloaded CHECK (is_downloaded IN (0, 1)) + ); +""" + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "downloaderrors", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=36), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_downloaderrors")), + sa.UniqueConstraint("name", name=op.f("uq_downloaderrors_name")), + ) + + conn = op.get_bind() + for name, member in db.DownloadErrorCodes.__members__.items(): + conn.execute("""INSERT INTO downloaderrors (name) VALUES (:name);""", name) + + op.rename_table("files", "files_tmp") + op.rename_table("messages", "messages_tmp") + op.rename_table("replies", "replies_tmp") + + conn.execute(CREATE_TABLE_FILES_NEW) + conn.execute(CREATE_TABLE_MESSAGES_NEW) + conn.execute(CREATE_TABLE_REPLIES_NEW) + + conn.execute(""" + INSERT INTO files + ( + id, uuid, filename, file_counter, size, download_url, + is_downloaded, is_read, is_decrypted, download_error_id, source_id, + last_updated + ) + SELECT id, uuid, filename, file_counter, size, download_url, + is_downloaded, is_read, is_decrypted, NULL, source_id, CURRENT_TIMESTAMP + FROM files_tmp + """) + + conn.execute(""" + INSERT INTO messages + ( + id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, is_read, download_error_id, download_url, last_updated + ) + SELECT id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, is_read, NULL, download_url, CURRENT_TIMESTAMP + FROM messages_tmp + """) + + conn.execute(""" + INSERT INTO replies + ( + id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, download_error_id, journalist_id, last_updated + ) + SELECT id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, NULL, journalist_id, CURRENT_TIMESTAMP + FROM replies_tmp + """) + + # Delete the old tables. + op.drop_table("files_tmp") + op.drop_table("messages_tmp") + op.drop_table("replies_tmp") + + # ### end Alembic commands ### + + +def downgrade(): + + conn = op.get_bind() + + op.rename_table("files", "files_tmp") + op.rename_table("messages", "messages_tmp") + op.rename_table("replies", "replies_tmp") + + conn.execute(CREATE_TABLE_FILES_OLD) + conn.execute(CREATE_TABLE_MESSAGES_OLD) + conn.execute(CREATE_TABLE_REPLIES_OLD) + + conn.execute(""" + INSERT INTO files + (id, uuid, filename, file_counter, size, download_url, + is_downloaded, is_read, is_decrypted, source_id) + SELECT id, uuid, filename, file_counter, size, download_url, + is_downloaded, is_read, is_decrypted, source_id + FROM files_tmp + """) + conn.execute(""" + INSERT INTO messages + (id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, is_read, download_url) + SELECT id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, is_read, download_url + FROM messages_tmp + """) + conn.execute(""" + INSERT INTO replies + (id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, journalist_id) + SELECT id, uuid, source_id, filename, file_counter, size, content, is_decrypted, + is_downloaded, journalist_id + FROM replies_tmp + """) + + # Delete the old tables. + op.drop_table("files_tmp") + op.drop_table("messages_tmp") + op.drop_table("replies_tmp") + + # Drop downloaderrors + op.drop_table("downloaderrors") + + # ### end Alembic commands ### diff --git a/create_dev_data.py b/create_dev_data.py index 808bb33a9..0365c51d2 100755 --- a/create_dev_data.py +++ b/create_dev_data.py @@ -5,22 +5,37 @@ import sys from securedrop_client.config import Config -from securedrop_client.db import Base, make_session_maker, ReplySendStatus, ReplySendStatusCodes +from securedrop_client import db sdc_home = sys.argv[1] -session = make_session_maker(sdc_home)() -Base.metadata.create_all(bind=session.get_bind()) +session = db.make_session_maker(sdc_home)() +db.Base.metadata.create_all(bind=session.get_bind()) -with open(os.path.join(sdc_home, Config.CONFIG_NAME), 'w') as f: - f.write(json.dumps({ - 'journalist_key_fingerprint': '65A1B5FF195B56353CC63DFFCC40EF1228271441', - })) +with open(os.path.join(sdc_home, Config.CONFIG_NAME), "w") as f: + f.write( + json.dumps( + {"journalist_key_fingerprint": "65A1B5FF195B56353CC63DFFCC40EF1228271441"} + ) + ) -for reply_send_status in ReplySendStatusCodes: +for reply_send_status in db.ReplySendStatusCodes: try: - reply_status = session.query(ReplySendStatus).filter_by( - name=reply_send_status.value).one() + reply_status = ( + session.query(db.ReplySendStatus) + .filter_by(name=reply_send_status.value) + .one() + ) except NoResultFound: - reply_status = ReplySendStatus(reply_send_status.value) + reply_status = db.ReplySendStatus(reply_send_status.value) session.add(reply_status) session.commit() + +for download_error in db.DownloadErrorCodes: + try: + download_error = ( + session.query(db.DownloadError).filter_by(name=download_error.name).one() + ) + except NoResultFound: + download_error = db.DownloadError(download_error.name) + session.add(download_error) + session.commit() diff --git a/securedrop_client/api_jobs/downloads.py b/securedrop_client/api_jobs/downloads.py index 5d5d89a7b..8e2f08390 100644 --- a/securedrop_client/api_jobs/downloads.py +++ b/securedrop_client/api_jobs/downloads.py @@ -15,7 +15,7 @@ from securedrop_client.api_jobs.base import ApiJob from securedrop_client.crypto import GpgHelper, CryptoError -from securedrop_client.db import File, Message, Reply +from securedrop_client.db import DownloadError, DownloadErrorCodes, File, Message, Reply from securedrop_client.storage import mark_as_decrypted, mark_as_downloaded, \ set_message_or_reply_content @@ -147,6 +147,11 @@ def _download(self, etag, download_path = self.call_download_api(api, db_object) if not self._check_file_integrity(etag, download_path): + download_error = session.query(DownloadError).filter_by( + name=DownloadErrorCodes.CHECKSUM_ERROR.name + ).one() + db_object.download_error = download_error + session.commit() exception = DownloadChecksumMismatchException( 'Downloaded file had an invalid checksum.', type(db_object), @@ -157,6 +162,7 @@ def _download(self, destination = db_object.location(self.data_dir) os.makedirs(os.path.dirname(destination), mode=0o700, exist_ok=True) shutil.move(download_path, destination) + db_object.download_error = None mark_as_downloaded(type(db_object), db_object.uuid, session) logger.info("File downloaded to {}".format(destination)) return destination @@ -173,6 +179,7 @@ def _decrypt(self, ''' try: original_filename = self.call_decrypt(filepath, session) + db_object.download_error = None mark_as_decrypted( type(db_object), db_object.uuid, session, original_filename=original_filename ) @@ -181,6 +188,11 @@ def _decrypt(self, ) except CryptoError as e: mark_as_decrypted(type(db_object), db_object.uuid, session, is_decrypted=False) + download_error = session.query(DownloadError).filter_by( + name=DownloadErrorCodes.DECRYPTION_ERROR.name + ).one() + db_object.download_error = download_error + session.commit() logger.debug("Failed to decrypt file: {}".format(os.path.basename(filepath))) raise DownloadDecryptionException( "Downloaded file could not be decrypted.", diff --git a/securedrop_client/db.py b/securedrop_client/db.py index c9cec14e0..780556f90 100644 --- a/securedrop_client/db.py +++ b/securedrop_client/db.py @@ -107,6 +107,12 @@ class Message(Base): nullable=True, ) + download_error_id = Column( + Integer, + ForeignKey('downloaderrors.id') + ) + download_error = relationship("DownloadError") + # This reflects read status stored on the server. is_read = Column(Boolean(name='is_read'), nullable=False, server_default=text("0")) @@ -123,6 +129,13 @@ class Message(Base): cascade="delete"), lazy="joined") + last_updated = Column( + DateTime, + nullable=False, + default=datetime.datetime.utcnow, + onupdate=datetime.datetime.utcnow, + ) + def __init__(self, **kwargs: Any) -> None: if 'file_counter' in kwargs: raise TypeError('Cannot manually set file_counter') @@ -137,6 +150,8 @@ def __str__(self) -> str: if self.content is not None: return self.content else: + if self.download_error is not None: + return self.download_error.explain(self.__class__.__name__) return '' def __repr__(self) -> str: @@ -181,6 +196,12 @@ class File(Base): nullable=True, ) + download_error_id = Column( + Integer, + ForeignKey('downloaderrors.id') + ) + download_error = relationship("DownloadError") + # This reflects read status stored on the server. is_read = Column(Boolean(name='is_read'), nullable=False, server_default=text("0")) @@ -190,6 +211,13 @@ class File(Base): cascade="delete"), lazy="joined") + last_updated = Column( + DateTime, + nullable=False, + default=datetime.datetime.utcnow, + onupdate=datetime.datetime.utcnow, + ) + def __init__(self, **kwargs: Any) -> None: if 'file_counter' in kwargs: raise TypeError('Cannot manually set file_counter') @@ -202,6 +230,8 @@ def __str__(self) -> str: Return something that's a useful string representation of the file. """ if self.is_downloaded: + if self.download_error is not None: + return self.download_error.explain(self.__class__.__name__) return "File: {}".format(self.filename) else: return '' @@ -264,6 +294,19 @@ class Reply(Base): nullable=True, ) + download_error_id = Column( + Integer, + ForeignKey('downloaderrors.id') + ) + download_error = relationship("DownloadError") + + last_updated = Column( + DateTime, + nullable=False, + default=datetime.datetime.utcnow, + onupdate=datetime.datetime.utcnow, + ) + def __init__(self, **kwargs: Any) -> None: if 'file_counter' in kwargs: raise TypeError('Cannot manually set file_counter') @@ -278,6 +321,8 @@ def __str__(self) -> str: if self.content is not None: return self.content else: + if self.download_error is not None: + return self.download_error.explain(self.__class__.__name__) return '' def __repr__(self) -> str: @@ -296,6 +341,40 @@ def location(self, data_dir: str) -> str: ) +class DownloadErrorCodes(Enum): + """ + Enumerated download failure modes, with templates as values. + + The templates are intended to be formatted with the class name of + a downloadable item. + """ + CHECKSUM_ERROR = "cannot download {object_type}" + DECRYPTION_ERROR = "cannot decrypt {object_type}" + + +class DownloadError(Base): + """ + Table of errors that can occur with downloadable items: File, Message, Reply. + """ + __tablename__ = 'downloaderrors' + + id = Column(Integer, primary_key=True) + name = Column(String(36), unique=True, nullable=False) + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + + def __repr__(self) -> str: + return "".format(self.name) + + def explain(self, classname: str) -> str: + """ + Formats the explanation type with the supplied class name. + """ + return DownloadErrorCodes[self.name].value.format(object_type=classname.lower()) + + class DraftReply(Base): __tablename__ = 'draftreplies' diff --git a/securedrop_client/gui/widgets.py b/securedrop_client/gui/widgets.py index 9b7c080b2..ebb3ff1b2 100644 --- a/securedrop_client/gui/widgets.py +++ b/securedrop_client/gui/widgets.py @@ -955,6 +955,8 @@ def setup(self, controller): self.controller.reply_ready.connect(self.set_snippet) self.controller.file_ready.connect(self.set_snippet) self.controller.file_missing.connect(self.set_snippet) + self.controller.message_download_failed.connect(self.set_snippet) + self.controller.reply_download_failed.connect(self.set_snippet) def update(self, sources: List[Source]) -> List[str]: """ @@ -1887,39 +1889,62 @@ class SpeechBubble(QWidget): and journalist. """ - CSS = ''' - #speech_bubble { - min-width: 540px; - max-width: 540px; - background-color: #fff; - } - #message { - font-family: 'Source Sans Pro'; - font-weight: 400; - font-size: 15px; - background-color: #fff; - padding: 16px; - } - #color_bar { - min-height: 5px; - max-height: 5px; - background-color: #102781; - border: 0px; + CSS = { + "speech_bubble": """ + min-width: 540px; + max-width: 540px; + background-color: #fff; + """, + "message": """ + min-width: 508px; + max-width: 508px; + font-family: 'Source Sans Pro'; + font-weight: 400; + font-size: 15px; + background-color: #fff; + padding: 16px; + """, + "color_bar": """ + min-height: 5px; + max-height: 5px; + background-color: #102781; + border: 0px; + """ + } + + CSS_ERROR = { + "speech_bubble": """ + min-width: 540px; + max-width: 540px; + background-color: #fff; + """, + "message": """ + min-width: 508px; + max-width: 508px; + font-family: 'Source Sans Pro'; + font-weight: 400; + font-size: 15px; + font-style: italic; + background-color: rgba(255, 255, 255, 0.6); + padding: 16px; + """, + "color_bar": """ + min-height: 5px; + max-height: 5px; + background-color: #BCBFCD; + border: 0px; + """ } - ''' TOP_MARGIN = 28 BOTTOM_MARGIN = 10 - def __init__(self, message_uuid: str, text: str, update_signal, index: int) -> None: + def __init__(self, message_uuid: str, text: str, update_signal, + download_error_signal, index: int, error: bool = False) -> None: super().__init__() self.uuid = message_uuid self.index = index - # Set styles - self.setStyleSheet(self.CSS) - self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) - # Set layout layout = QVBoxLayout() self.setLayout(layout) @@ -1937,10 +1962,10 @@ def __init__(self, message_uuid: str, text: str, update_signal, index: int) -> N self.color_bar.setObjectName('color_bar') # Speech bubble - speech_bubble = QWidget() - speech_bubble.setObjectName('speech_bubble') + self.speech_bubble = QWidget() + self.speech_bubble.setObjectName('speech_bubble') speech_bubble_layout = QVBoxLayout() - speech_bubble.setLayout(speech_bubble_layout) + self.speech_bubble.setLayout(speech_bubble_layout) speech_bubble_layout.addWidget(self.message) speech_bubble_layout.addWidget(self.color_bar) speech_bubble_layout.setContentsMargins(0, 0, 0, 0) @@ -1952,7 +1977,7 @@ def __init__(self, message_uuid: str, text: str, update_signal, index: int) -> N self.bubble_area_layout = QHBoxLayout() self.bubble_area_layout.setContentsMargins(0, self.TOP_MARGIN, 0, self.BOTTOM_MARGIN) bubble_area.setLayout(self.bubble_area_layout) - self.bubble_area_layout.addWidget(speech_bubble) + self.bubble_area_layout.addWidget(self.speech_bubble) # Add widget to layout layout.addWidget(bubble_area) @@ -1961,8 +1986,16 @@ def __init__(self, message_uuid: str, text: str, update_signal, index: int) -> N self.message.setTextInteractionFlags(Qt.TextSelectableByMouse) self.message.setContextMenuPolicy(Qt.NoContextMenu) + # Set styles + if error: + self.set_error_styles() + else: + self.set_normal_styles() + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + # Connect signals to slots update_signal.connect(self._update_text) + download_error_signal.connect(self.set_error) @pyqtSlot(str, str, str) def _update_text(self, source_id: str, message_uuid: str, text: str) -> None: @@ -1972,6 +2005,26 @@ def _update_text(self, source_id: str, message_uuid: str, text: str) -> None: """ if message_uuid == self.uuid: self.message.setText(text) + self.set_normal_styles() + + @pyqtSlot(str, str, str) + def set_error(self, source_uuid: str, uuid: str, text: str): + """ + Adjust style and text to indicate an error. + """ + if uuid == self.uuid: + self.message.setText(text) + self.set_error_styles() + + def set_normal_styles(self): + self.speech_bubble.setStyleSheet(self.CSS["speech_bubble"]) + self.message.setStyleSheet(self.CSS["message"]) + self.color_bar.setStyleSheet(self.CSS["color_bar"]) + + def set_error_styles(self): + self.speech_bubble.setStyleSheet(self.CSS_ERROR["speech_bubble"]) + self.message.setStyleSheet(self.CSS_ERROR["message"]) + self.color_bar.setStyleSheet(self.CSS_ERROR["color_bar"]) class MessageWidget(SpeechBubble): @@ -1979,8 +2032,9 @@ class MessageWidget(SpeechBubble): Represents an incoming message from the source. """ - def __init__(self, message_uuid: str, message: str, update_signal, index: int) -> None: - super().__init__(message_uuid, message, update_signal, index) + def __init__(self, message_uuid: str, message: str, update_signal, + download_error_signal, index: int, error: bool = False) -> None: + super().__init__(message_uuid, message, update_signal, download_error_signal, index, error) class ReplyWidget(SpeechBubble): @@ -1988,60 +2042,84 @@ class ReplyWidget(SpeechBubble): Represents a reply to a source. """ - CSS_MESSAGE_REPLY_FAILED = ''' - font-family: 'Source Sans Pro'; - font-weight: 400; - font-size: 15px; - background-color: #fff; - color: #3b3b3b; - padding: 16px; - ''' - - CSS_COLOR_BAR_REPLY_FAILED = ''' - min-height: 5px; - max-height: 5px; - background-color: #ff3366; - border: 0px; - ''' - - CSS_ERROR_MESSAGE_REPLY_FAILED = ''' + CSS = { + "color_bar": """ + min-height: 5px; + max-height: 5px; + background-color: #0065db; + border: 0px; + """, + "message": """ + min-width: 508px; + max-width: 508px; + font-family: 'Source Sans Pro'; + font-weight: 400; + font-size: 15px; + background-color: #fff; + color: #3b3b3b; + padding: 16px; + """, + "speech_bubble": """ + min-width: 540px; + max-width: 540px; + background-color: #fff; + """, + } + + CSS_ERROR_MESSAGE = """ font-family: 'Source Sans Pro'; font-weight: 500; font-size: 13px; color: #ff3366; - ''' - - CSS_MESSAGE_REPLY_SUCCEEDED = ''' - font-family: 'Source Sans Pro'; - font-weight: 400; - font-size: 15px; - background-color: #fff; - color: #3b3b3b; - padding: 16px; - ''' - - CSS_COLOR_BAR_REPLY_SUCCEEDED = ''' - min-height: 5px; - max-height: 5px; - background-color: #0065db; - border: 0px; - ''' - - CSS_MESSAGE_REPLY_PENDING = ''' - font-family: 'Source Sans Pro'; - font-weight: 400; - font-size: 15px; - color: #A9AAAD; - background-color: #F7F8FC; - padding: 16px; - ''' + """ - CSS_COLOR_BAR_REPLY_PENDING = ''' - min-height: 5px; - max-height: 5px; - background-color: #0065db; - border: 0px; - ''' + CSS_REPLY_FAILED = { + "color_bar": """ + min-height: 5px; + max-height: 5px; + background-color: #ff3366; + border: 0px; + """, + "message": """ + min-width: 508px; + max-width: 508px; + font-family: 'Source Sans Pro'; + font-weight: 400; + font-size: 15px; + background-color: #fff; + color: #3b3b3b; + padding: 16px; + """, + "speech_bubble": """ + min-width: 540px; + max-width: 540px; + background-color: #fff; + """, + } + + CSS_REPLY_PENDING = { + "color_bar": """ + min-height: 5px; + max-height: 5px; + background-color: #0065db; + border: 0px; + """, + "message": """ + min-width: 508px; + max-width: 508px; + font-family: 'Source Sans Pro'; + font-weight: 400; + font-size: 15px; + color: #A9AAAD; + background-color: #F7F8FC; + padding: 16px; + """, + "speech_bubble": """ + min-width: 540px; + max-width: 540px; + background-color: #fff; + """, + } def __init__( self, @@ -2049,11 +2127,13 @@ def __init__( message: str, reply_status: str, update_signal, + download_error_signal, message_succeeded_signal, message_failed_signal, index: int, + error: bool = False, ) -> None: - super().__init__(message_uuid, message, update_signal, index) + super().__init__(message_uuid, message, update_signal, download_error_signal, index, error) self.uuid = message_uuid error_icon = SvgLabel('error_icon.svg', svg_size=QSize(12, 12)) @@ -2061,7 +2141,7 @@ def __init__( error_icon.setFixedWidth(12) error_message = SecureQLabel('Failed to send', wordwrap=False) error_message.setObjectName('error_message') - error_message.setStyleSheet(self.CSS_ERROR_MESSAGE_REPLY_FAILED) + error_message.setStyleSheet(self.CSS_ERROR_MESSAGE) self.error = QWidget() error_layout = QHBoxLayout() @@ -2078,20 +2158,21 @@ def __init__( message_failed_signal.connect(self._on_reply_failure) # Set styles - self._set_reply_state(reply_status) + if error: + self.set_error_styles() + else: + self._set_reply_state(reply_status) def _set_reply_state(self, status: str) -> None: + logger.debug("Setting ReplyWidget state: %s", status) if status == 'SUCCEEDED': - self.message.setStyleSheet(self.CSS_MESSAGE_REPLY_SUCCEEDED) - self.color_bar.setStyleSheet(self.CSS_COLOR_BAR_REPLY_SUCCEEDED) + self.set_normal_styles() self.error.hide() elif status == 'FAILED': - self.message.setStyleSheet(self.CSS_MESSAGE_REPLY_FAILED) - self.color_bar.setStyleSheet(self.CSS_COLOR_BAR_REPLY_FAILED) + self.set_failed_styles() self.error.show() elif status == 'PENDING': - self.message.setStyleSheet(self.CSS_MESSAGE_REPLY_PENDING) - self.color_bar.setStyleSheet(self.CSS_COLOR_BAR_REPLY_PENDING) + self.set_pending_styles() @pyqtSlot(str, str, str) def _on_reply_success(self, source_id: str, message_uuid: str, content: str) -> None: @@ -2111,6 +2192,16 @@ def _on_reply_failure(self, message_uuid: str) -> None: if message_uuid == self.uuid: self._set_reply_state('FAILED') + def set_failed_styles(self): + self.speech_bubble.setStyleSheet(self.CSS_REPLY_FAILED["speech_bubble"]) + self.message.setStyleSheet(self.CSS_REPLY_FAILED["message"]) + self.color_bar.setStyleSheet(self.CSS_REPLY_FAILED["color_bar"]) + + def set_pending_styles(self): + self.speech_bubble.setStyleSheet(self.CSS_REPLY_PENDING["speech_bubble"]) + self.message.setStyleSheet(self.CSS_REPLY_PENDING["message"]) + self.color_bar.setStyleSheet(self.CSS_REPLY_PENDING["color_bar"]) + class FileWidget(QWidget): """ @@ -3171,7 +3262,8 @@ def add_file(self, file: File, index): self.controller, self.controller.file_ready, self.controller.file_missing, - index) + index, + ) self.conversation_layout.insertWidget(index, conversation_item, alignment=Qt.AlignLeft) self.current_messages[file.uuid] = conversation_item self.conversation_updated.emit() @@ -3190,7 +3282,13 @@ def add_message(self, message: Message, index) -> None: Add a message from the source. """ conversation_item = MessageWidget( - message.uuid, str(message), self.controller.message_ready, index) + message.uuid, + str(message), + self.controller.message_ready, + self.controller.message_download_failed, + index, + message.download_error is not None, + ) self.conversation_layout.insertWidget(index, conversation_item, alignment=Qt.AlignLeft) self.current_messages[message.uuid] = conversation_item self.conversation_updated.emit() @@ -3210,9 +3308,12 @@ def add_reply(self, reply: Union[DraftReply, Reply], index) -> None: str(reply), send_status, self.controller.reply_ready, + self.controller.reply_download_failed, self.controller.reply_succeeded, self.controller.reply_failed, - index) + index, + getattr(reply, "download_error", None) is not None, + ) self.conversation_layout.insertWidget(index, conversation_item, alignment=Qt.AlignRight) self.current_messages[reply.uuid] = conversation_item @@ -3226,6 +3327,7 @@ def add_reply_from_reply_box(self, uuid: str, content: str) -> None: content, 'PENDING', self.controller.reply_ready, + self.controller.reply_download_failed, self.controller.reply_succeeded, self.controller.reply_failed, index) diff --git a/securedrop_client/logic.py b/securedrop_client/logic.py index 694339a01..80ed41b94 100644 --- a/securedrop_client/logic.py +++ b/securedrop_client/logic.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ import arrow -from datetime import datetime +import datetime import functools import inspect import logging @@ -35,8 +35,8 @@ from securedrop_client import db from securedrop_client.api_jobs.base import ApiInaccessibleError from securedrop_client.api_jobs.downloads import ( - DownloadChecksumMismatchException, DownloadDecryptionException, FileDownloadJob, - MessageDownloadJob, ReplyDownloadJob, + DownloadChecksumMismatchException, DownloadDecryptionException, DownloadException, + FileDownloadJob, MessageDownloadJob, ReplyDownloadJob, ) from securedrop_client.api_jobs.sources import DeleteSourceJob, DeleteSourceJobException from securedrop_client.api_jobs.uploads import SendReplyJob, SendReplyJobError, \ @@ -155,6 +155,16 @@ class Controller(QObject): """ reply_ready = pyqtSignal(str, str, str) + """ + This signal indicates an error while downloading a reply. + + Emits: + str: the reply's source UUID + str: the reply UUID + str: the content of the reply + """ + reply_download_failed = pyqtSignal(str, str, str) + """ This signal indicates that a message has been successfully downloaded. @@ -165,6 +175,16 @@ class Controller(QObject): """ message_ready = pyqtSignal(str, str, str) + """ + This signal indicates an error while downloading a message. + + Emits: + str: the message's source UUID + str: the message UUID + str: the content of the message + """ + message_download_failed = pyqtSignal(str, str, str) + """ This signal indicates that a file has been successfully downloaded. @@ -216,8 +236,10 @@ class Controller(QObject): """ source_deletion_failed = pyqtSignal(str) - def __init__(self, hostname: str, gui, session_maker: sessionmaker, - home: str, proxy: bool = True, qubes: bool = True) -> None: + def __init__( + self, hostname: str, gui, session_maker: sessionmaker, + home: str, proxy: bool = True, qubes: bool = True + ) -> None: """ The hostname, gui and session objects are used to coordinate with the various other layers of the application: the location of the SecureDrop @@ -314,6 +336,8 @@ def setup(self): self.export.moveToThread(self.export_thread) self.export_thread.start() + storage.clear_download_errors(self.session) + def call_api(self, api_call_func, success_callback, @@ -604,13 +628,18 @@ def _submit_download_job(self, self.api_job_queue.enqueue(job) def download_new_messages(self) -> None: - messages = storage.find_new_messages(self.session) - - if len(messages) > 0: + new_messages = storage.find_new_messages(self.session) + new_message_count = len(new_messages) + if new_message_count > 0: self.set_status(_('Retrieving new messages'), 2500) - for message in messages: - self._submit_download_job(type(message), message.uuid) + for message in new_messages: + if message.download_error: + logger.info( + f"Download of message {message.uuid} failed since client start; not retrying." + ) + else: + self._submit_download_job(type(message), message.uuid) def on_message_download_success(self, uuid: str) -> None: """ @@ -620,21 +649,33 @@ def on_message_download_success(self, uuid: str) -> None: message = storage.get_message(self.session, uuid) self.message_ready.emit(message.source.uuid, message.uuid, message.content) - def on_message_download_failure(self, exception: Exception) -> None: + def on_message_download_failure(self, exception: DownloadException) -> None: """ Called when a message fails to download. """ logger.info('Failed to download message: {}'.format(exception)) - # Keep resubmitting the job if the download is corrupted. if isinstance(exception, DownloadChecksumMismatchException): + # Keep resubmitting the job if the download is corrupted. logger.warning('Failure due to checksum mismatch, retrying {}'.format(exception.uuid)) self._submit_download_job(exception.object_type, exception.uuid) + self.session.commit() + try: + message = storage.get_message(self.session, exception.uuid) + self.message_download_failed.emit(message.source.uuid, message.uuid, str(message)) + except Exception as e: + logger.error(f"Could not emit message_download_failed: {e}") + def download_new_replies(self) -> None: replies = storage.find_new_replies(self.session) for reply in replies: - self._submit_download_job(type(reply), reply.uuid) + if reply.download_error: + logger.info( + f"Download of reply {reply.uuid} failed since client start; not retrying." + ) + else: + self._submit_download_job(type(reply), reply.uuid) def on_reply_download_success(self, uuid: str) -> None: """ @@ -644,17 +685,24 @@ def on_reply_download_success(self, uuid: str) -> None: reply = storage.get_reply(self.session, uuid) self.reply_ready.emit(reply.source.uuid, reply.uuid, reply.content) - def on_reply_download_failure(self, exception: Exception) -> None: + def on_reply_download_failure(self, exception: DownloadException) -> None: """ Called when a reply fails to download. """ logger.info('Failed to download reply: {}'.format(exception)) - # Keep resubmitting the job if the download is corrupted. if isinstance(exception, DownloadChecksumMismatchException): + # Keep resubmitting the job if the download is corrupted. logger.warning('Failure due to checksum mismatch, retrying {}'.format(exception.uuid)) self._submit_download_job(exception.object_type, exception.uuid) + self.session.commit() + try: + reply = storage.get_reply(self.session, exception.uuid) + self.reply_download_failed.emit(reply.source.uuid, reply.uuid, str(reply)) + except Exception as e: + logger.error(f"Could not emit reply_download_failed: {e}") + def downloaded_file_exists(self, file: db.File) -> bool: ''' Check if the file specified by file_uuid exists. If it doesn't update the local db and @@ -766,7 +814,7 @@ def on_file_download_success(self, uuid: Any) -> None: """ self.session.commit() file_obj = storage.get_file(self.session, uuid) - # Let us update the size of the file. + file_obj.download_error = None storage.update_file_size(uuid, self.data_dir, self.session) self.file_ready.emit(file_obj.source.uuid, uuid, file_obj.filename) @@ -829,7 +877,7 @@ def send_reply(self, source_uuid: str, reply_uuid: str, message: str) -> None: name=db.ReplySendStatusCodes.PENDING.value).one() draft_reply = db.DraftReply( uuid=reply_uuid, - timestamp=datetime.utcnow(), + timestamp=datetime.datetime.utcnow(), source_id=source.id, journalist_id=self.api.token_journalist_uuid, file_counter=source.interaction_count, diff --git a/securedrop_client/storage.py b/securedrop_client/storage.py index 4880ceb35..1ffdef979 100644 --- a/securedrop_client/storage.py +++ b/securedrop_client/storage.py @@ -611,3 +611,13 @@ def mark_all_pending_drafts_as_failed(session: Session) -> List[DraftReply]: session.commit() return pending_drafts + + +def clear_download_errors(session: Session) -> None: + """ + Clears all File, Message, or Reply download errors. + """ + session.execute("""UPDATE files SET download_error_id = null;""") + session.execute("""UPDATE messages SET download_error_id = null;""") + session.execute("""UPDATE replies SET download_error_id = null;""") + session.commit() diff --git a/tests/api_jobs/test_downloads.py b/tests/api_jobs/test_downloads.py index 182f8d712..b9865ade8 100644 --- a/tests/api_jobs/test_downloads.py +++ b/tests/api_jobs/test_downloads.py @@ -166,7 +166,9 @@ def test_MessageDownloadJob_no_download_or_decrypt(mocker, homedir, session, ses assert message_is_decrypted_none.is_decrypted is True -def test_MessageDownloadJob_message_already_decrypted(mocker, homedir, session, session_maker): +def test_MessageDownloadJob_message_already_decrypted( + mocker, homedir, session, session_maker, download_error_codes +): """ Test that call_api just returns uuid if already decrypted. """ @@ -187,7 +189,9 @@ def test_MessageDownloadJob_message_already_decrypted(mocker, homedir, session, download_fn.assert_not_called() -def test_MessageDownloadJob_message_already_downloaded(mocker, homedir, session, session_maker): +def test_MessageDownloadJob_message_already_downloaded( + mocker, homedir, session, session_maker, download_error_codes +): """ Test that call_api just decrypts and returns uuid if already downloaded. """ @@ -256,7 +260,9 @@ def test_MessageDownloadJob_with_base_error(mocker, homedir, session, session_ma decrypt_fn.assert_not_called() -def test_MessageDownloadJob_with_crypto_error(mocker, homedir, session, session_maker): +def test_MessageDownloadJob_with_crypto_error( + mocker, homedir, session, session_maker, download_error_codes +): """ Test when a message successfully downloads, but does not successfully decrypt. Use the `homedir` fixture to get a GPG keyring. @@ -363,7 +369,9 @@ def fake_download(sdk_obj: SdkSubmission, timeout: int) -> Tuple[str, str]: assert mock_decrypt.called -def test_FileDownloadJob_happy_path_sha256_etag(mocker, homedir, session, session_maker): +def test_FileDownloadJob_happy_path_sha256_etag( + mocker, homedir, session, session_maker, download_error_codes +): source = factory.Source() file_ = factory.File(source=source, is_downloaded=None, is_decrypted=None) session.add(source) @@ -401,7 +409,9 @@ def fake_download(sdk_obj: SdkSubmission, timeout: int) -> Tuple[str, str]: assert mock_decrypt.called -def test_FileDownloadJob_bad_sha256_etag(mocker, homedir, session, session_maker): +def test_FileDownloadJob_bad_sha256_etag( + mocker, homedir, session, session_maker, download_error_codes +): source = factory.Source() file_ = factory.File(source=source, is_downloaded=None, is_decrypted=None) session.add(source) @@ -476,7 +486,9 @@ def fake_download(sdk_obj: SdkSubmission, timeout: int) -> Tuple[str, str]: assert mock_decrypt.called -def test_FileDownloadJob_decryption_error(mocker, homedir, session, session_maker): +def test_FileDownloadJob_decryption_error( + mocker, homedir, session, session_maker, download_error_codes +): source = factory.Source() file_ = factory.File(source=source, is_downloaded=None, is_decrypted=None) session.add(source) diff --git a/tests/conftest.py b/tests/conftest.py index 575abc285..183dcca75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,8 +7,10 @@ from datetime import datetime from securedrop_client.config import Config from securedrop_client.app import configure_locale_and_language -from securedrop_client.db import (Base, make_session_maker, Source, ReplySendStatus, - ReplySendStatusCodes) +from securedrop_client.db import ( + Base, DownloadError, DownloadErrorCodes, ReplySendStatus, + ReplySendStatusCodes, Source, make_session_maker +) from uuid import uuid4 @@ -103,6 +105,15 @@ def reply_status_codes(session) -> None: return +@pytest.fixture(scope='function') +def download_error_codes(session) -> None: + for download_error_code in DownloadErrorCodes: + download_error = DownloadError(download_error_code.name) + session.add(download_error) + session.commit() + return + + @pytest.fixture(scope='function') def source(session) -> dict: args = { diff --git a/tests/factory.py b/tests/factory.py index 6981c6007..46d01be02 100644 --- a/tests/factory.py +++ b/tests/factory.py @@ -7,7 +7,7 @@ from typing import List import uuid -from sdclientapi import Source as SDKSource +from sdclientapi import Reply as SDKReply, Source as SDKSource from securedrop_client import db from securedrop_client.api_jobs.base import ApiJob @@ -195,3 +195,23 @@ def RemoteSource(**attrs): defaults.update(attrs) return SDKSource(**defaults) + + +def RemoteReply(**attrs): + + source_url = "/api/v1/sources/{}".format(str(uuid.uuid4())) + defaults = dict( + filename="1-reply.filename", + journalist_uuid=str(uuid.uuid4()), + journalist_username="test", + file_counter=1, + is_deleted_by_source=False, + reply_url="test", + size=1234, + uuid=str(uuid.uuid4()), + source_url=source_url, + ) + + defaults.update(attrs) + + return SDKReply(**defaults) diff --git a/tests/gui/test_widgets.py b/tests/gui/test_widgets.py index 5fcd0d2f1..67bfb1f7a 100644 --- a/tests/gui/test_widgets.py +++ b/tests/gui/test_widgets.py @@ -2024,16 +2024,45 @@ def test_SpeechBubble_init(mocker): Check the speech bubble is configured correctly (there's a label containing the passed in text). """ - mock_signal = mocker.Mock() - mock_connect = mocker.Mock() - mock_signal.connect = mock_connect + mock_update_signal = mocker.Mock() + mock_update_connect = mocker.Mock() + mock_update_signal.connect = mock_update_connect + + mock_download_error_signal = mocker.Mock() + mock_download_error_connect = mocker.Mock() + mock_download_error_signal.connect = mock_download_error_connect - sb = SpeechBubble('mock id', 'hello', mock_signal, 0) - ss = sb.styleSheet() + sb = SpeechBubble('mock id', 'hello', mock_update_signal, mock_download_error_signal, 0) sb.message.text() == 'hello' - assert mock_connect.called - assert 'background-color' in ss + assert mock_update_connect.called + assert mock_download_error_connect.called + assert 'background-color: #102781;' in sb.color_bar.styleSheet() + assert 'background-color: #fff;' in sb.speech_bubble.styleSheet() + + +def test_SpeechBubble_init_with_error(mocker): + """ + Check the speech bubble is configured correctly when error=True. + """ + mock_update_signal = mocker.Mock() + mock_update_connect = mocker.Mock() + mock_update_signal.connect = mock_update_connect + + mock_download_error_signal = mocker.Mock() + mock_download_error_connect = mocker.Mock() + mock_download_error_signal.connect = mock_download_error_connect + + sb = SpeechBubble( + 'mock id', 'hello', mock_update_signal, mock_download_error_signal, 0, error=True + ) + + sb.message.text() == 'hello' + assert mock_update_connect.called + assert mock_download_error_connect.called + assert 'background-color: #BCBFCD;' in sb.color_bar.styleSheet() + assert 'background-color: rgba(255, 255, 255, 0.6);' in sb.message.styleSheet() + assert 'font-style: italic;' in sb.message.styleSheet() def test_SpeechBubble_update_text(mocker): @@ -2043,7 +2072,7 @@ def test_SpeechBubble_update_text(mocker): mock_signal = mocker.MagicMock() msg_id = 'abc123' - sb = SpeechBubble(msg_id, 'hello', mock_signal, 0) + sb = SpeechBubble(msg_id, 'hello', mock_signal, mock_signal, 0) new_msg = 'new message' sb._update_text('mock_source_uuid', msg_id, new_msg) @@ -2061,7 +2090,7 @@ def test_SpeechBubble_html_init(mocker): """ mock_signal = mocker.MagicMock() - bubble = SpeechBubble('mock id', 'hello', mock_signal, 0) + bubble = SpeechBubble('mock id', 'hello', mock_signal, mock_signal, 0) assert bubble.message.text() == 'hello' @@ -2070,9 +2099,25 @@ def test_SpeechBubble_with_apostrophe_in_text(mocker): mock_signal = mocker.MagicMock() message = "I'm sure, you are reading my message." - bubble = SpeechBubble('mock id', message, mock_signal, 0) + bubble = SpeechBubble('mock id', message, mock_signal, mock_signal, 0) + assert bubble.message.text() == message + + +def test_SpeechBubble_set_error(mocker): + mock_signal = mocker.MagicMock() + + message_uuid = "mock id" + message = "I'm sure, you are reading my message." + bubble = SpeechBubble(message_uuid, message, mock_signal, mock_signal, 0) assert bubble.message.text() == message + error_message = "Oh no." + bubble.set_error("source id", message_uuid, error_message) + assert bubble.message.text() == error_message + assert "font-style: italic;" in bubble.message.styleSheet() + assert "background-color: rgba(255, 255, 255, 0.6);" in bubble.message.styleSheet() + assert "background-color: #BCBFCD;" in bubble.color_bar.styleSheet() + def test_MessageWidget_init(mocker): """ @@ -2082,7 +2127,7 @@ def test_MessageWidget_init(mocker): mock_connected = mocker.Mock() mock_signal.connect = mock_connected - MessageWidget('mock id', 'hello', mock_signal, 0) + MessageWidget('mock id', 'hello', mock_signal, mock_signal, 0) assert mock_connected.called @@ -2095,6 +2140,10 @@ def test_ReplyWidget_init(mocker): mock_update_connected = mocker.Mock() mock_update_signal.connect = mock_update_connected + mock_download_failure_signal = mocker.MagicMock() + mock_download_failure_connected = mocker.Mock() + mock_download_failure_signal.connect = mock_download_failure_connected + mock_success_signal = mocker.MagicMock() mock_success_connected = mocker.Mock() mock_success_signal.connect = mock_success_connected @@ -2108,15 +2157,57 @@ def test_ReplyWidget_init(mocker): 'hello', 'dummy', mock_update_signal, + mock_download_failure_signal, + mock_success_signal, + mock_failure_signal, + 0, + ) + + assert mock_update_connected.called + assert mock_success_connected.called + assert mock_failure_connected.called + + +def test_ReplyWidget_init_with_error(mocker): + """ + Check the CSS is set as expected when error=True. + """ + mock_update_signal = mocker.Mock() + mock_update_connected = mocker.Mock() + mock_update_signal.connect = mock_update_connected + + mock_download_failure_signal = mocker.MagicMock() + mock_download_failure_connected = mocker.Mock() + mock_download_failure_signal.connect = mock_download_failure_connected + + mock_success_signal = mocker.MagicMock() + mock_success_connected = mocker.Mock() + mock_success_signal.connect = mock_success_connected + + mock_failure_signal = mocker.MagicMock() + mock_failure_connected = mocker.Mock() + mock_failure_signal.connect = mock_failure_connected + + rw = ReplyWidget( + 'mock id', + 'hello', + 'dummy', + mock_update_signal, + mock_download_failure_signal, mock_success_signal, mock_failure_signal, 0, + error=True ) assert mock_update_connected.called - assert mock_success_connected.calledd + assert mock_success_connected.called assert mock_failure_connected.called + assert "font-style: italic;" in rw.message.styleSheet() + assert "background-color: rgba(255, 255, 255, 0.6);" in rw.message.styleSheet() + assert "background-color: #BCBFCD;" in rw.color_bar.styleSheet() + def test_FileWidget_init_file_not_downloaded(mocker, source, session): """ @@ -3342,7 +3433,12 @@ def test_ConversationView_add_message(mocker, session, source): source = source['source'] # grab the source from the fixture dict for simplicity mock_message_ready_signal = mocker.MagicMock() - mocked_controller = mocker.MagicMock(session=session, message_ready=mock_message_ready_signal) + mock_message_download_failed_signal = mocker.MagicMock() + mocked_controller = mocker.MagicMock( + session=session, + message_ready=mock_message_ready_signal, + message_download_failed=mock_message_download_failed_signal, + ) content = 'a sea, a bee' message = factory.Message(source=source, content=content) @@ -3361,7 +3457,14 @@ def test_ConversationView_add_message(mocker, session, source): cv.add_message(message, 0) # check that we built the widget was called with the correct args - mock_msg_widget.assert_called_once_with(message.uuid, content, mock_message_ready_signal, 0) + mock_msg_widget.assert_called_once_with( + message.uuid, + content, + mock_message_ready_signal, + mock_message_download_failed_signal, + 0, + False, + ) # check that we added the correct widget to the layout cv.conversation_layout.insertWidget.assert_called_once_with( @@ -3381,7 +3484,12 @@ def test_ConversationView_add_message_no_content(mocker, session, source): source = source['source'] # grab the source from the fixture dict for simplicity mock_message_ready_signal = mocker.MagicMock() - mocked_controller = mocker.MagicMock(session=session, message_ready=mock_message_ready_signal) + mock_message_download_failed_signal = mocker.MagicMock() + mocked_controller = mocker.MagicMock( + session=session, + message_ready=mock_message_ready_signal, + message_download_failed=mock_message_download_failed_signal + ) message = factory.Message(source=source, is_decrypted=False, content=None) session.add(message) @@ -3399,7 +3507,9 @@ def test_ConversationView_add_message_no_content(mocker, session, source): # check that we built the widget was called with the correct args mock_msg_widget.assert_called_once_with( - message.uuid, '', mock_message_ready_signal, 0) + message.uuid, '', mock_message_ready_signal, + mock_message_download_failed_signal, 0, False + ) # check that we added the correct widget to the layout cv.conversation_layout.insertWidget.assert_called_once_with( @@ -3443,10 +3553,15 @@ def test_ConversationView_add_reply_from_reply_box(mocker): """ source = factory.Source() reply_ready = mocker.MagicMock() + reply_download_failed = mocker.MagicMock() reply_succeeded = mocker.MagicMock() reply_failed = mocker.MagicMock() controller = mocker.MagicMock( - reply_ready=reply_ready, reply_succeeded=reply_succeeded, reply_failed=reply_failed) + reply_ready=reply_ready, + reply_download_failed=reply_download_failed, + reply_succeeded=reply_succeeded, + reply_failed=reply_failed + ) cv = ConversationView(source, controller) cv.conversation_layout = mocker.MagicMock() reply_widget_res = mocker.MagicMock() @@ -3456,7 +3571,9 @@ def test_ConversationView_add_reply_from_reply_box(mocker): cv.add_reply_from_reply_box('abc123', 'test message') reply_widget.assert_called_once_with( - 'abc123', 'test message', 'PENDING', reply_ready, reply_succeeded, reply_failed, 0) + 'abc123', 'test message', 'PENDING', reply_ready, reply_download_failed, + reply_succeeded, reply_failed, 0 + ) cv.conversation_layout.insertWidget.assert_called_once_with( 0, reply_widget_res, alignment=Qt.AlignRight) @@ -3468,12 +3585,16 @@ def test_ConversationView_add_reply(mocker, session, source): source = source['source'] # grab the source from the fixture dict for simplicity mock_reply_ready_signal = mocker.MagicMock() + mock_reply_download_failed_signal = mocker.MagicMock() mock_reply_succeeded_signal = mocker.MagicMock() mock_reply_failed_signal = mocker.MagicMock() - mocked_controller = mocker.MagicMock(session=session, - reply_ready=mock_reply_ready_signal, - reply_succeeded=mock_reply_succeeded_signal, - reply_failed=mock_reply_failed_signal) + mocked_controller = mocker.MagicMock( + session=session, + reply_ready=mock_reply_ready_signal, + reply_download_failed=mock_reply_download_failed_signal, + reply_succeeded=mock_reply_succeeded_signal, + reply_failed=mock_reply_failed_signal + ) content = 'a sea, a bee' reply = factory.Reply(source=source, content=content) @@ -3496,9 +3617,12 @@ def test_ConversationView_add_reply(mocker, session, source): content, 'SUCCEEDED', mock_reply_ready_signal, + mock_reply_download_failed_signal, mock_reply_succeeded_signal, mock_reply_failed_signal, - 0) + 0, + False + ) # check that we added the correct widget to the layout cv.conversation_layout.insertWidget.assert_called_once_with( @@ -3514,10 +3638,12 @@ def test_ConversationView_add_reply_no_content(mocker, session, source): source = source['source'] # grab the source from the fixture dict for simplicity mock_reply_ready_signal = mocker.MagicMock() + mock_reply_download_failed_signal = mocker.MagicMock() mock_reply_succeeded_signal = mocker.MagicMock() mock_reply_failed_signal = mocker.MagicMock() mocked_controller = mocker.MagicMock(session=session, reply_ready=mock_reply_ready_signal, + reply_download_failed=mock_reply_download_failed_signal, reply_succeeded=mock_reply_succeeded_signal, reply_failed=mock_reply_failed_signal) @@ -3541,9 +3667,12 @@ def test_ConversationView_add_reply_no_content(mocker, session, source): '', 'SUCCEEDED', mock_reply_ready_signal, + mock_reply_download_failed_signal, mock_reply_succeeded_signal, mock_reply_failed_signal, - 0) + 0, + False + ) # check that we added the correct widget to the layout cv.conversation_layout.insertWidget.assert_called_once_with( @@ -3923,6 +4052,7 @@ def pretend_source_was_deleted(self): def test_ReplyWidget_success_failure_slots(mocker): mock_update_signal = mocker.Mock() + mock_download_failed_signal = mocker.Mock() mock_success_signal = mocker.Mock() mock_failure_signal = mocker.Mock() msg_id = 'abc123' @@ -3931,6 +4061,7 @@ def test_ReplyWidget_success_failure_slots(mocker): 'lol', 'PENDING', mock_update_signal, + mock_download_failed_signal, mock_success_signal, mock_failure_signal, 0) @@ -3939,6 +4070,7 @@ def test_ReplyWidget_success_failure_slots(mocker): mock_success_signal.connect.assert_called_once_with(widget._on_reply_success) mock_failure_signal.connect.assert_called_once_with(widget._on_reply_failure) assert mock_update_signal.connect.called # to ensure no stale mocks + assert mock_download_failed_signal.connect.called # check the success slog widget._on_reply_success('mock_source_id', msg_id + "x", 'lol') diff --git a/tests/test_logic.py b/tests/test_logic.py index 9636f79c2..f3d076f59 100644 --- a/tests/test_logic.py +++ b/tests/test_logic.py @@ -3,6 +3,8 @@ expected. """ import arrow +import datetime +import logging import os import pytest @@ -20,6 +22,7 @@ from securedrop_client.api_jobs.updatestar import UpdateStarJobError, UpdateStarJobTimeoutError from securedrop_client.api_jobs.uploads import SendReplyJobError, SendReplyJobTimeoutError + with open(os.path.join(os.path.dirname(__file__), 'files', 'test-key.gpg.pub.asc')) as f: PUB_KEY = f.read() @@ -82,7 +85,7 @@ def test_Controller_init(homedir, config, mocker, session_maker): assert co.api_threads == {} -def test_Controller_setup(homedir, config, mocker, session_maker): +def test_Controller_setup(homedir, config, mocker, session_maker, session): """ Ensure the application is set up with the following default state: Using the `config` fixture to ensure the config is written to disk. @@ -1170,7 +1173,7 @@ def test_Controller_on_reply_downloaded_success(mocker, homedir, session_maker): co.on_reply_download_success(reply.uuid) - reply_ready.emit.assert_called_once_with(reply.source.uuid, reply.uuid, reply.content) + reply_ready.emit.assert_called_once_with(reply.source.uuid, reply.uuid, str(reply)) def test_Controller_on_reply_downloaded_failure(mocker, homedir, session_maker): @@ -1184,7 +1187,7 @@ def test_Controller_on_reply_downloaded_failure(mocker, homedir, session_maker): info_logger = mocker.patch('securedrop_client.logic.logger.info') co._submit_download_job = mocker.MagicMock() - co.on_reply_download_failure('mock_exception') + co.on_reply_download_failure(Exception('mock_exception')) info_logger.assert_called_once_with('Failed to download reply: mock_exception') reply_ready.emit.assert_not_called() @@ -1217,6 +1220,25 @@ def test_Controller_on_reply_downloaded_checksum_failure(mocker, homedir, sessio 'Failure due to checksum mismatch, retrying {}'.format(reply.uuid) +def test_Controller_on_reply_downloaded_decryption_failure(mocker, homedir, session_maker): + """ + Check that a failed download due to a decryption error informs the user. + """ + co = Controller('http://localhost', mocker.MagicMock(), session_maker, homedir) + reply_ready = mocker.patch.object(co, 'reply_ready') + reply_download_failed = mocker.patch.object(co, 'reply_download_failed') + reply = factory.Reply(source=factory.Source()) + mocker.patch('securedrop_client.storage.get_reply', return_value=reply) + info_logger = mocker.patch('securedrop_client.logic.logger.info') + + decryption_exception = DownloadDecryptionException('bang!', type(reply), reply.uuid) + co.on_reply_download_failure(decryption_exception) + + info_logger.call_args_list[0][0][0] == 'Failed to download reply: bang!' + reply_ready.emit.assert_not_called() + reply_download_failed.emit.assert_called_with(reply.source.uuid, reply.uuid, str(reply)) + + def test_Controller_download_new_messages_with_new_message(mocker, session, session_maker, homedir): """ Test that `download_new_messages` enqueues a job, connects to the right slots, and sets a @@ -1265,6 +1287,69 @@ def test_Controller_download_new_messages_without_messages(mocker, session, sess set_status.assert_not_called() +def test_Controller_download_new_messages_skips_recent_failures( + mocker, session, session_maker, homedir, download_error_codes +): + """ + Test that `download_new_messages` skips recently failed downloads. + """ + co = Controller("http://localhost", mocker.MagicMock(), session_maker, homedir) + co.api = "Api token has a value" + + # record the download failures + download_error = session.query(db.DownloadError).filter_by( + name=db.DownloadErrorCodes.DECRYPTION_ERROR.name + ).one() + + message = factory.Message(source=factory.Source()) + message.download_error = download_error + session.commit() + + mocker.patch("securedrop_client.storage.find_new_messages", return_value=[message]) + api_job_queue = mocker.patch.object(co, "api_job_queue") + mocker.patch("securedrop_client.logic.logger.isEnabledFor", return_value=logging.DEBUG) + info_logger = mocker.patch("securedrop_client.logic.logger.info") + + co.download_new_messages() + + api_job_queue.enqueue.assert_not_called() + info_logger.call_args_list[0][0][0] == ( + f"Download of message {message.uuid} failed since client start; not retrying." + ) + + +def test_Controller_download_new_replies_skips_recent_failures( + mocker, session, session_maker, homedir, download_error_codes +): + """ + Test that `download_new_replies` skips recently failed downloads. + """ + co = Controller("http://localhost", mocker.MagicMock(), session_maker, homedir) + co.api = "Api token has a value" + + # record the download failures + download_error = session.query(db.DownloadError).filter_by( + name=db.DownloadErrorCodes.DECRYPTION_ERROR.name + ).one() + + reply = factory.Reply(source=factory.Source()) + reply.download_error = download_error + reply.last_updated = datetime.datetime.utcnow() + session.commit() + + mocker.patch("securedrop_client.storage.find_new_replies", return_value=[reply]) + api_job_queue = mocker.patch.object(co, "api_job_queue") + mocker.patch("securedrop_client.logic.logger.isEnabledFor", return_value=logging.DEBUG) + info_logger = mocker.patch("securedrop_client.logic.logger.info") + + co.download_new_replies() + + api_job_queue.enqueue.assert_not_called() + info_logger.call_args_list[0][0][0] == ( + f"Download of reply {reply.uuid} failed since client start; not retrying." + ) + + def test_Controller_on_message_downloaded_success(mocker, homedir, session_maker): """ Check that a successful download emits proper signal. @@ -1276,7 +1361,7 @@ def test_Controller_on_message_downloaded_success(mocker, homedir, session_maker co.on_message_download_success(message.uuid) - message_ready.emit.assert_called_once_with(message.source.uuid, message.uuid, message.content) + message_ready.emit.assert_called_once_with(message.source.uuid, message.uuid, str(message)) def test_Controller_on_message_downloaded_failure(mocker, homedir, session_maker): @@ -1290,7 +1375,7 @@ def test_Controller_on_message_downloaded_failure(mocker, homedir, session_maker co._submit_download_job = mocker.MagicMock() info_logger = mocker.patch('securedrop_client.logic.logger.info') - co.on_message_download_failure('mock_exception') + co.on_message_download_failure(Exception('mock_exception')) info_logger.assert_called_once_with('Failed to download message: mock_exception') message_ready.emit.assert_not_called() @@ -1323,6 +1408,25 @@ def test_Controller_on_message_downloaded_checksum_failure(mocker, homedir, sess 'Failure due to checksum mismatch, retrying {}'.format(message.uuid) +def test_Controller_on_message_downloaded_decryption_failure(mocker, homedir, session_maker): + """ + Check that a failed download due to a decryption error informs the user. + """ + co = Controller('http://localhost', mocker.MagicMock(), session_maker, homedir) + message_ready = mocker.patch.object(co, 'message_ready') + message_download_failed = mocker.patch.object(co, 'message_download_failed') + message = factory.Message(source=factory.Source()) + mocker.patch('securedrop_client.storage.get_message', return_value=message) + info_logger = mocker.patch('securedrop_client.logic.logger.info') + + decryption_exception = DownloadDecryptionException('bang!', type(message), message.uuid) + co.on_message_download_failure(decryption_exception) + + info_logger.call_args_list[0][0][0] == 'Failed to download message: bang!' + message_ready.emit.assert_not_called() + message_download_failed.emit.assert_called_with(message.source.uuid, message.uuid, str(message)) + + def test_Controller_on_delete_source_success(mocker, homedir): ''' Test that on a successful deletion does not delete the source locally (regression). diff --git a/tests/test_models.py b/tests/test_models.py index 30adeb86c..520505c72 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,7 +2,9 @@ import pytest from tests import factory -from securedrop_client.db import DraftReply, Reply, File, Message, ReplySendStatus, User +from securedrop_client.db import ( + DownloadError, DownloadErrorCodes, DraftReply, Reply, File, Message, ReplySendStatus, User +) def test_user_fullname(): @@ -122,6 +124,11 @@ def test_string_representation_of_send_reply_status(): reply_status.__repr__() +def test_string_representation_of_download_error(): + e = DownloadError(name="teehee") + assert repr(e) == "" + + def test_source_collection(): # Create some test submissions and replies source = factory.Source() @@ -239,3 +246,39 @@ def test_reply_init(): r = Reply(filename="1-foo") assert r.file_counter == 1 + + +def test_file_with_download_error(session, download_error_codes): + f = factory.File() + download_error = session.query(DownloadError).filter_by( + name=DownloadErrorCodes.CHECKSUM_ERROR.name + ).one() + f.download_error = download_error + session.commit() + + classname = f.__class__.__name__.lower() + assert str(f) == f"cannot download {classname}" + + +def test_message_with_download_error(session, download_error_codes): + m = factory.Message(is_decrypted=False, content=None) + download_error = session.query(DownloadError).filter_by( + name=DownloadErrorCodes.DECRYPTION_ERROR.name + ).one() + m.download_error = download_error + session.commit() + + classname = m.__class__.__name__.lower() + assert str(m) == f"cannot decrypt {classname}" + + +def test_reply_with_download_error(session, download_error_codes): + r = factory.Reply(is_decrypted=False, content=None) + download_error = session.query(DownloadError).filter_by( + name=DownloadErrorCodes.DECRYPTION_ERROR.name + ).one() + r.download_error = download_error + session.commit() + + classname = r.__class__.__name__.lower() + assert str(r) == f"cannot decrypt {classname}" diff --git a/tests/test_storage.py b/tests/test_storage.py index b7d0634d5..4e31eb386 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -687,21 +687,10 @@ def test_update_replies(homedir, mocker, session): source = factory.Source() session.add(source) - # Some remote reply objects from the API, one of which will exist in the - # local database, the other will NOT exist in the local database - # (this will be added to the database) - remote_reply_update = make_remote_reply(source.uuid, journalist.uuid) - remote_reply_create = make_remote_reply(source.uuid, journalist.uuid) - remote_reply_create.file_counter = 3 - remote_reply_create.filename = "3-reply.gpg" - - remote_replies = [remote_reply_update, remote_reply_create] - # Some local reply objects. One already exists in the API results # (this will be updated), one does NOT exist in the API results (this will # be deleted from the local database). local_reply_update = factory.Reply( - uuid=remote_reply_update.uuid, source_id=source.id, source=source, journalist_id=journalist.id, @@ -717,7 +706,27 @@ def test_update_replies(homedir, mocker, session): session.add(local_reply_delete) local_replies = [local_reply_update, local_reply_delete] + # Some remote reply objects from the API, one of which will exist in the + # local database, the other will NOT exist in the local database + # (this will be added to the database) + remote_reply_update = factory.RemoteReply( + journalist_uuid=journalist.uuid, + uuid=local_reply_update.uuid, + source_url="/api/v1/sources/{}".format(source.uuid), + file_counter=local_reply_update.file_counter, + filename=local_reply_update.filename, + ) + + remote_reply_create = factory.RemoteReply( + journalist_uuid=journalist.uuid, + source_url="/api/v1/sources/{}".format(source.uuid), + file_counter=factory.REPLY_COUNT + 1, + filename="{}-filename.gpg".format(factory.REPLY_COUNT + 1), + ) + remote_replies = [remote_reply_update, remote_reply_create] + + session.commit() update_replies(remote_replies, local_replies, session, data_dir) session.commit()