Skip to content

Commit

Permalink
Reduce the amount of database work done on sync
Browse files Browse the repository at this point in the history
Avoid dirtying synced objects on update unless they've actually
changed, and reduce the number of commits.
  • Loading branch information
rmol authored and sssoleileraaa committed Apr 1, 2020
1 parent 2668b9b commit 37a3e5e
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 67 deletions.
57 changes: 36 additions & 21 deletions securedrop_client/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import shutil
from pathlib import Path
from dateutil.parser import parse
from typing import List, Tuple, Type, Union
from typing import Any, List, Tuple, Type, Union

from sqlalchemy import and_, desc, or_
from sqlalchemy.orm.exc import NoResultFound
Expand Down Expand Up @@ -133,8 +133,21 @@ def update_local_storage(session: Session,
update_replies(remote_replies, get_local_replies(session), session, data_dir)


def update_sources(remote_sources: List[SDKSource], local_sources:
List[Source], session: Session, data_dir: str) -> None:
def lazy_setattr(o: Any, a: str, v: Any) -> None:
"""
Only assign v to o.a if they differ.
Intended to avoid unnecessarily dirtying SQLAlchemy objects during
sync.
"""
if getattr(o, a) != v:
setattr(o, a, v)


def update_sources(
remote_sources: List[SDKSource], local_sources: List[Source],
session: Session, data_dir: str
) -> None:
"""
Given collections of remote sources, the current local sources and a
session to the local database, ensure the state of the local database
Expand All @@ -150,15 +163,13 @@ def update_sources(remote_sources: List[SDKSource], local_sources:
if source.uuid in local_sources_by_uuid:
# Update an existing record.
local_source = local_sources_by_uuid[source.uuid]
local_source.journalist_designation = source.journalist_designation
local_source.is_flagged = source.is_flagged
local_source.interaction_count = source.interaction_count
local_source.document_count = source.number_of_documents
local_source.is_starred = source.is_starred
local_source.last_updated = parse(source.last_updated)
local_source.public_key = source.key['public']
local_source.fingerprint = source.key['fingerprint']
session.commit()
lazy_setattr(local_source, "journalist_designation", source.journalist_designation)
lazy_setattr(local_source, "is_flagged", source.is_flagged)
lazy_setattr(local_source, "interaction_count", source.interaction_count)
lazy_setattr(local_source, "document_count", source.number_of_documents)
lazy_setattr(local_source, "is_starred", source.is_starred)
lazy_setattr(local_source, "last_updated", parse(source.last_updated))
lazy_setattr(local_source, "public_key", source.key['public'])

# Removing the UUID from local_sources_by_uuid ensures
# this record won't be deleted at the end of this
Expand All @@ -179,7 +190,6 @@ def update_sources(remote_sources: List[SDKSource], local_sources:
fingerprint=source.key['fingerprint'],
)
session.add(ns)
session.commit()

logger.debug('Added new source {}'.format(source.uuid))

Expand Down Expand Up @@ -222,9 +232,9 @@ def __update_submissions(model: Union[Type[File], Type[Message]],
local_submission = [s for s in local_submissions
if s.uuid == submission.uuid][0]

local_submission.size = submission.size
local_submission.is_read = submission.is_read
local_submission.download_url = submission.download_url
lazy_setattr(local_submission, "size", submission.size)
lazy_setattr(local_submission, "is_read", submission.is_read)
lazy_setattr(local_submission, "download_url", submission.download_url)

# Removing the UUID from local_uuids ensures this record won't be
# deleted at the end of this function.
Expand Down Expand Up @@ -268,8 +278,9 @@ def update_replies(remote_replies: List[SDKReply], local_replies: List[Reply],
local_reply = [r for r in local_replies if r.uuid == reply.uuid][0]

user = find_or_create_user(reply.journalist_uuid, reply.journalist_username, session)
local_reply.journalist_id = user.id
local_reply.size = reply.size
lazy_setattr(local_reply, "journalist_id", user.id)
lazy_setattr(local_reply, "size", reply.size)
lazy_setattr(local_reply, "filename", reply.filename)

local_uuids.remove(reply.uuid)
logger.debug('Updated reply {}'.format(reply.uuid))
Expand Down Expand Up @@ -375,8 +386,11 @@ def update_missing_files(data_dir: str, session: Session) -> List[File]:
return files_that_are_missing


def update_draft_replies(session: Session, source_id: int, timestamp: datetime,
old_file_counter: int, new_file_counter: int) -> None:
def update_draft_replies(
session: Session, source_id: int, timestamp: datetime,
old_file_counter: int, new_file_counter: int,
commit: bool = True
) -> None:
"""
When we confirm a sent reply R, if there are drafts that were sent after it,
we need to reposition them to ensure that they appear _after_ the confirmed
Expand Down Expand Up @@ -410,7 +424,8 @@ def update_draft_replies(session: Session, source_id: int, timestamp: datetime,
.all():
draft_reply.file_counter = new_file_counter
session.add(draft_reply)
session.commit()
if commit:
session.commit()


def find_new_files(session: Session) -> List[File]:
Expand Down
86 changes: 40 additions & 46 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def test_update_messages(homedir, mocker):
assert mock_session.commit.call_count == 1


def test_update_replies(homedir, mocker):
def test_update_replies(homedir, mocker, session):
"""
Check that:
Expand All @@ -679,67 +679,61 @@ def test_update_replies(homedir, mocker):
* References to journalist's usernames are correctly handled.
"""
data_dir = os.path.join(homedir, 'data')
mock_session = mocker.MagicMock()
# Source object related to the submissions.
source = mocker.MagicMock()
source.uuid = str(uuid.uuid4())
source.journalist_filename = 'test'

journalist = factory.User(id=1)
session.add(journalist)

source = factory.Source()
session.add(source)

# Some remote reply objects from the API, one of which will exist in the
# local database, the other will NOT exist in the local database
# (this will be added to the database)
remote_reply_update = make_remote_reply(source.uuid)
remote_reply_create = make_remote_reply(source.uuid, 'unknownuser')
remote_reply_update = make_remote_reply(source.uuid, journalist.uuid)
remote_reply_create = make_remote_reply(source.uuid, journalist.uuid)
remote_reply_create.file_counter = 3
remote_reply_create.filename = "3-reply.gpg"

remote_replies = [remote_reply_update, remote_reply_create]

# Some local reply objects. One already exists in the API results
# (this will be updated), one does NOT exist in the API results (this will
# be deleted from the local database).
local_reply_update = mocker.MagicMock()
local_reply_update.uuid = remote_reply_update.uuid
local_filename = "originalsubmissionname.txt"
local_reply_update.filename = local_filename
local_reply_update.journalist_uuid = str(uuid.uuid4())
local_reply_delete = mocker.MagicMock()
local_reply_delete.uuid = str(uuid.uuid4())
local_reply_delete.filename = "local_reply_delete.filename"
local_reply_delete.journalist_uuid = str(uuid.uuid4())
local_reply_delete_source_dir = os.path.join(homedir, source.journalist_filename)
local_reply_delete.location = mocker.MagicMock(
return_value=os.path.join(local_reply_delete_source_dir, local_reply_delete.filename))
local_reply_update = factory.Reply(
uuid=remote_reply_update.uuid,
source_id=source.id,
source=source,
journalist_id=journalist.id,
filename="1-original-reply.gpg.",
size=2,
)
session.add(local_reply_update)

local_reply_delete = factory.Reply(
source_id=source.id,
source=source,
)
session.add(local_reply_delete)

local_replies = [local_reply_update, local_reply_delete]
# There needs to be a corresponding local_source and local_user
local_source = mocker.MagicMock()
local_source.uuid = source.uuid
local_source.id = 666 # };-)
local_user = mocker.MagicMock()
local_user.username = remote_reply_create.journalist_username
local_user.id = 42
mock_session.query().filter_by.side_effect = [[local_source, ],
NoResultFound()]
mock_focu = mocker.MagicMock(return_value=local_user)
mocker.patch('securedrop_client.storage.find_or_create_user', mock_focu)

update_replies(remote_replies, local_replies, mock_session, data_dir)
update_replies(remote_replies, local_replies, session, data_dir)
session.commit()

# Check the expected local reply object has been updated with values
# from the API.
assert local_reply_update.journalist_id == local_user.id
assert local_reply_update.filename == local_filename
assert local_reply_update.journalist_id == journalist.id
assert local_reply_update.size == remote_reply_update.size
assert local_reply_update.filename == remote_reply_update.filename

# Check the expected local source object has been created with values from
# the API.
assert mock_session.add.call_count == 1
new_reply = mock_session.add.call_args_list[0][0][0]
assert new_reply.uuid == remote_reply_create.uuid
assert new_reply.source_id == local_source.id
assert new_reply.journalist_id == local_user.id
new_reply = session.query(db.Reply).filter_by(uuid=remote_reply_create.uuid).one()
assert new_reply.source_id == source.id
assert new_reply.journalist_id == journalist.id
assert new_reply.size == remote_reply_create.size
assert new_reply.filename == remote_reply_create.filename
# Ensure the record for the local source that is missing from the results
# of the API is deleted.
mock_session.delete.assert_called_once_with(local_reply_delete)
# Session is committed to database.
assert mock_session.commit.call_count == 1

# Ensure the local reply that is not in the API results is deleted.
assert session.query(db.Reply).filter_by(uuid=local_reply_delete.uuid).count() == 0


def test_update_replies_cleanup_drafts(homedir, mocker, session):
Expand Down

0 comments on commit 37a3e5e

Please sign in to comment.