From 7cb9a22347eeb3e75396ae69294ed4f4ed31d14b Mon Sep 17 00:00:00 2001 From: John Hensley Date: Tue, 31 Mar 2020 15:32:35 -0400 Subject: [PATCH] Reduce the amount of database work done on sync Avoid dirtying synced objects on update unless they've actually changed, and reduce the number of commits. --- securedrop_client/storage.py | 51 +++++++++++++-------- tests/test_storage.py | 86 +++++++++++++++++------------------- 2 files changed, 72 insertions(+), 65 deletions(-) diff --git a/securedrop_client/storage.py b/securedrop_client/storage.py index 11dc69bd6e..72cf46a8ca 100644 --- a/securedrop_client/storage.py +++ b/securedrop_client/storage.py @@ -25,7 +25,7 @@ import shutil from pathlib import Path from dateutil.parser import parse -from typing import List, Tuple, Type, Union +from typing import Any, List, Tuple, Type, Union from sqlalchemy import and_, desc, or_ from sqlalchemy.orm.exc import NoResultFound @@ -135,6 +135,17 @@ def update_local_storage(session: Session, update_replies(remote_replies, get_local_replies(session), session, data_dir) +def lazy_setattr(o: Any, a: str, v: Any) -> None: + """ + Only assign v to o.a if they differ. + + Intended to avoid unnecessarily dirtying SQLAlchemy objects during + sync. + """ + if getattr(o, a) != v: + setattr(o, a, v) + + def update_sources(gpg: GpgHelper, remote_sources: List[SDKSource], local_sources: List[Source], session: Session, data_dir: str) -> None: """ @@ -152,15 +163,13 @@ def update_sources(gpg: GpgHelper, remote_sources: List[SDKSource], if source.uuid in local_sources_by_uuid: # Update an existing record. local_source = local_sources_by_uuid[source.uuid] - local_source.journalist_designation = source.journalist_designation - local_source.is_flagged = source.is_flagged - local_source.interaction_count = source.interaction_count - local_source.document_count = source.number_of_documents - local_source.is_starred = source.is_starred - local_source.last_updated = parse(source.last_updated) - local_source.public_key = source.key['public'] - local_source.fingerprint = source.key['fingerprint'] - session.commit() + lazy_setattr(local_source, "journalist_designation", source.journalist_designation) + lazy_setattr(local_source, "is_flagged", source.is_flagged) + lazy_setattr(local_source, "interaction_count", source.interaction_count) + lazy_setattr(local_source, "document_count", source.number_of_documents) + lazy_setattr(local_source, "is_starred", source.is_starred) + lazy_setattr(local_source, "last_updated", parse(source.last_updated)) + lazy_setattr(local_source, "public_key", source.key['public']) # Removing the UUID from local_sources_by_uuid ensures # this record won't be deleted at the end of this @@ -181,7 +190,6 @@ def update_sources(gpg: GpgHelper, remote_sources: List[SDKSource], fingerprint=source.key['fingerprint'], ) session.add(ns) - session.commit() logger.debug('Added new source {}'.format(source.uuid)) @@ -224,9 +232,9 @@ def __update_submissions(model: Union[Type[File], Type[Message]], local_submission = [s for s in local_submissions if s.uuid == submission.uuid][0] - local_submission.size = submission.size - local_submission.is_read = submission.is_read - local_submission.download_url = submission.download_url + lazy_setattr(local_submission, "size", submission.size) + lazy_setattr(local_submission, "is_read", submission.is_read) + lazy_setattr(local_submission, "download_url", submission.download_url) # Removing the UUID from local_uuids ensures this record won't be # deleted at the end of this function. @@ -270,8 +278,9 @@ def update_replies(remote_replies: List[SDKReply], local_replies: List[Reply], local_reply = [r for r in local_replies if r.uuid == reply.uuid][0] user = find_or_create_user(reply.journalist_uuid, reply.journalist_username, session) - local_reply.journalist_id = user.id - local_reply.size = reply.size + lazy_setattr(local_reply, "journalist_id", user.id) + lazy_setattr(local_reply, "size", reply.size) + lazy_setattr(local_reply, "filename", reply.filename) local_uuids.remove(reply.uuid) logger.debug('Updated reply {}'.format(reply.uuid)) @@ -377,8 +386,11 @@ def update_missing_files(data_dir: str, session: Session) -> List[File]: return files_that_are_missing -def update_draft_replies(session: Session, source_id: int, timestamp: datetime, - old_file_counter: int, new_file_counter: int) -> None: +def update_draft_replies( + session: Session, source_id: int, timestamp: datetime, + old_file_counter: int, new_file_counter: int, + commit: bool = True +) -> None: """ When we confirm a sent reply R, if there are drafts that were sent after it, we need to reposition them to ensure that they appear _after_ the confirmed @@ -412,7 +424,8 @@ def update_draft_replies(session: Session, source_id: int, timestamp: datetime, .all(): draft_reply.file_counter = new_file_counter session.add(draft_reply) - session.commit() + if commit: + session.commit() def find_new_files(session: Session) -> List[File]: diff --git a/tests/test_storage.py b/tests/test_storage.py index 4d56a061de..3e7b418439 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -676,7 +676,7 @@ def test_update_messages(homedir, mocker): assert mock_session.commit.call_count == 1 -def test_update_replies(homedir, mocker): +def test_update_replies(homedir, mocker, session): """ Check that: @@ -687,67 +687,61 @@ def test_update_replies(homedir, mocker): * References to journalist's usernames are correctly handled. """ data_dir = os.path.join(homedir, 'data') - mock_session = mocker.MagicMock() - # Source object related to the submissions. - source = mocker.MagicMock() - source.uuid = str(uuid.uuid4()) - source.journalist_filename = 'test' + + journalist = factory.User(id=1) + session.add(journalist) + + source = factory.Source() + session.add(source) + # Some remote reply objects from the API, one of which will exist in the # local database, the other will NOT exist in the local database # (this will be added to the database) - remote_reply_update = make_remote_reply(source.uuid) - remote_reply_create = make_remote_reply(source.uuid, 'unknownuser') + remote_reply_update = make_remote_reply(source.uuid, journalist.uuid) + remote_reply_create = make_remote_reply(source.uuid, journalist.uuid) + remote_reply_create.file_counter = 3 + remote_reply_create.filename = "3-reply.gpg" + remote_replies = [remote_reply_update, remote_reply_create] + # Some local reply objects. One already exists in the API results # (this will be updated), one does NOT exist in the API results (this will # be deleted from the local database). - local_reply_update = mocker.MagicMock() - local_reply_update.uuid = remote_reply_update.uuid - local_filename = "originalsubmissionname.txt" - local_reply_update.filename = local_filename - local_reply_update.journalist_uuid = str(uuid.uuid4()) - local_reply_delete = mocker.MagicMock() - local_reply_delete.uuid = str(uuid.uuid4()) - local_reply_delete.filename = "local_reply_delete.filename" - local_reply_delete.journalist_uuid = str(uuid.uuid4()) - local_reply_delete_source_dir = os.path.join(homedir, source.journalist_filename) - local_reply_delete.location = mocker.MagicMock( - return_value=os.path.join(local_reply_delete_source_dir, local_reply_delete.filename)) + local_reply_update = factory.Reply( + uuid=remote_reply_update.uuid, + source_id=source.id, + source=source, + journalist_id=journalist.id, + filename="1-original-reply.gpg.", + size=2, + ) + session.add(local_reply_update) + + local_reply_delete = factory.Reply( + source_id=source.id, + source=source, + ) + session.add(local_reply_delete) + local_replies = [local_reply_update, local_reply_delete] - # There needs to be a corresponding local_source and local_user - local_source = mocker.MagicMock() - local_source.uuid = source.uuid - local_source.id = 666 # };-) - local_user = mocker.MagicMock() - local_user.username = remote_reply_create.journalist_username - local_user.id = 42 - mock_session.query().filter_by.side_effect = [[local_source, ], - NoResultFound()] - mock_focu = mocker.MagicMock(return_value=local_user) - mocker.patch('securedrop_client.storage.find_or_create_user', mock_focu) - update_replies(remote_replies, local_replies, mock_session, data_dir) + update_replies(remote_replies, local_replies, session, data_dir) + session.commit() # Check the expected local reply object has been updated with values # from the API. - assert local_reply_update.journalist_id == local_user.id - assert local_reply_update.filename == local_filename + assert local_reply_update.journalist_id == journalist.id assert local_reply_update.size == remote_reply_update.size + assert local_reply_update.filename == remote_reply_update.filename - # Check the expected local source object has been created with values from - # the API. - assert mock_session.add.call_count == 1 - new_reply = mock_session.add.call_args_list[0][0][0] - assert new_reply.uuid == remote_reply_create.uuid - assert new_reply.source_id == local_source.id - assert new_reply.journalist_id == local_user.id + new_reply = session.query(db.Reply).filter_by(uuid=remote_reply_create.uuid).one() + assert new_reply.source_id == source.id + assert new_reply.journalist_id == journalist.id assert new_reply.size == remote_reply_create.size assert new_reply.filename == remote_reply_create.filename - # Ensure the record for the local source that is missing from the results - # of the API is deleted. - mock_session.delete.assert_called_once_with(local_reply_delete) - # Session is committed to database. - assert mock_session.commit.call_count == 1 + + # Ensure the local reply that is not in the API results is deleted. + assert session.query(db.Reply).filter_by(uuid=local_reply_delete.uuid).count() == 0 def test_update_replies_cleanup_drafts(homedir, mocker, session):