Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimize sync db work #1036

Merged
merged 1 commit into from
Apr 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions securedrop_client/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import shutil
from pathlib import Path
from dateutil.parser import parse
from typing import List, Tuple, Type, Union
from typing import Any, List, Tuple, Type, Union

from sqlalchemy import and_, desc, or_
from sqlalchemy.orm.exc import NoResultFound
Expand Down Expand Up @@ -133,8 +133,21 @@ def update_local_storage(session: Session,
update_replies(remote_replies, get_local_replies(session), session, data_dir)


def update_sources(remote_sources: List[SDKSource], local_sources:
List[Source], session: Session, data_dir: str) -> None:
def lazy_setattr(o: Any, a: str, v: Any) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minimal names!

"""
Only assign v to o.a if they differ.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


Intended to avoid unnecessarily dirtying SQLAlchemy objects during
sync.
"""
if getattr(o, a) != v:
setattr(o, a, v)


def update_sources(
remote_sources: List[SDKSource], local_sources: List[Source],
session: Session, data_dir: str
) -> None:
"""
Given collections of remote sources, the current local sources and a
session to the local database, ensure the state of the local database
Expand All @@ -150,15 +163,13 @@ def update_sources(remote_sources: List[SDKSource], local_sources:
if source.uuid in local_sources_by_uuid:
# Update an existing record.
local_source = local_sources_by_uuid[source.uuid]
local_source.journalist_designation = source.journalist_designation
local_source.is_flagged = source.is_flagged
local_source.interaction_count = source.interaction_count
local_source.document_count = source.number_of_documents
local_source.is_starred = source.is_starred
local_source.last_updated = parse(source.last_updated)
local_source.public_key = source.key['public']
local_source.fingerprint = source.key['fingerprint']
session.commit()
lazy_setattr(local_source, "journalist_designation", source.journalist_designation)
lazy_setattr(local_source, "is_flagged", source.is_flagged)
lazy_setattr(local_source, "interaction_count", source.interaction_count)
lazy_setattr(local_source, "document_count", source.number_of_documents)
lazy_setattr(local_source, "is_starred", source.is_starred)
lazy_setattr(local_source, "last_updated", parse(source.last_updated))
lazy_setattr(local_source, "public_key", source.key['public'])

# Removing the UUID from local_sources_by_uuid ensures
# this record won't be deleted at the end of this
Expand All @@ -179,7 +190,6 @@ def update_sources(remote_sources: List[SDKSource], local_sources:
fingerprint=source.key['fingerprint'],
)
session.add(ns)
session.commit()

logger.debug('Added new source {}'.format(source.uuid))

Expand Down Expand Up @@ -222,9 +232,9 @@ def __update_submissions(model: Union[Type[File], Type[Message]],
local_submission = [s for s in local_submissions
if s.uuid == submission.uuid][0]

local_submission.size = submission.size
local_submission.is_read = submission.is_read
local_submission.download_url = submission.download_url
lazy_setattr(local_submission, "size", submission.size)
lazy_setattr(local_submission, "is_read", submission.is_read)
lazy_setattr(local_submission, "download_url", submission.download_url)

# Removing the UUID from local_uuids ensures this record won't be
# deleted at the end of this function.
Expand Down Expand Up @@ -268,8 +278,9 @@ def update_replies(remote_replies: List[SDKReply], local_replies: List[Reply],
local_reply = [r for r in local_replies if r.uuid == reply.uuid][0]

user = find_or_create_user(reply.journalist_uuid, reply.journalist_username, session)
local_reply.journalist_id = user.id
local_reply.size = reply.size
lazy_setattr(local_reply, "journalist_id", user.id)
lazy_setattr(local_reply, "size", reply.size)
lazy_setattr(local_reply, "filename", reply.filename)

local_uuids.remove(reply.uuid)
logger.debug('Updated reply {}'.format(reply.uuid))
Expand Down Expand Up @@ -375,8 +386,11 @@ def update_missing_files(data_dir: str, session: Session) -> List[File]:
return files_that_are_missing


def update_draft_replies(session: Session, source_id: int, timestamp: datetime,
old_file_counter: int, new_file_counter: int) -> None:
def update_draft_replies(
session: Session, source_id: int, timestamp: datetime,
old_file_counter: int, new_file_counter: int,
commit: bool = True
) -> None:
"""
When we confirm a sent reply R, if there are drafts that were sent after it,
we need to reposition them to ensure that they appear _after_ the confirmed
Expand Down Expand Up @@ -410,7 +424,8 @@ def update_draft_replies(session: Session, source_id: int, timestamp: datetime,
.all():
draft_reply.file_counter = new_file_counter
session.add(draft_reply)
session.commit()
if commit:
session.commit()


def find_new_files(session: Session) -> List[File]:
Expand Down
86 changes: 40 additions & 46 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def test_update_messages(homedir, mocker):
assert mock_session.commit.call_count == 1


def test_update_replies(homedir, mocker):
def test_update_replies(homedir, mocker, session):
"""
Check that:

Expand All @@ -679,67 +679,61 @@ def test_update_replies(homedir, mocker):
* References to journalist's usernames are correctly handled.
"""
data_dir = os.path.join(homedir, 'data')
mock_session = mocker.MagicMock()
# Source object related to the submissions.
source = mocker.MagicMock()
source.uuid = str(uuid.uuid4())
source.journalist_filename = 'test'

journalist = factory.User(id=1)
session.add(journalist)

source = factory.Source()
session.add(source)

# Some remote reply objects from the API, one of which will exist in the
# local database, the other will NOT exist in the local database
# (this will be added to the database)
remote_reply_update = make_remote_reply(source.uuid)
remote_reply_create = make_remote_reply(source.uuid, 'unknownuser')
remote_reply_update = make_remote_reply(source.uuid, journalist.uuid)
remote_reply_create = make_remote_reply(source.uuid, journalist.uuid)
remote_reply_create.file_counter = 3
remote_reply_create.filename = "3-reply.gpg"

remote_replies = [remote_reply_update, remote_reply_create]

# Some local reply objects. One already exists in the API results
# (this will be updated), one does NOT exist in the API results (this will
# be deleted from the local database).
local_reply_update = mocker.MagicMock()
local_reply_update.uuid = remote_reply_update.uuid
local_filename = "originalsubmissionname.txt"
local_reply_update.filename = local_filename
local_reply_update.journalist_uuid = str(uuid.uuid4())
local_reply_delete = mocker.MagicMock()
local_reply_delete.uuid = str(uuid.uuid4())
local_reply_delete.filename = "local_reply_delete.filename"
local_reply_delete.journalist_uuid = str(uuid.uuid4())
local_reply_delete_source_dir = os.path.join(homedir, source.journalist_filename)
local_reply_delete.location = mocker.MagicMock(
return_value=os.path.join(local_reply_delete_source_dir, local_reply_delete.filename))
local_reply_update = factory.Reply(
uuid=remote_reply_update.uuid,
source_id=source.id,
source=source,
journalist_id=journalist.id,
filename="1-original-reply.gpg.",
size=2,
)
session.add(local_reply_update)

local_reply_delete = factory.Reply(
source_id=source.id,
source=source,
)
session.add(local_reply_delete)

local_replies = [local_reply_update, local_reply_delete]
# There needs to be a corresponding local_source and local_user
local_source = mocker.MagicMock()
local_source.uuid = source.uuid
local_source.id = 666 # };-)
local_user = mocker.MagicMock()
local_user.username = remote_reply_create.journalist_username
local_user.id = 42
mock_session.query().filter_by.side_effect = [[local_source, ],
NoResultFound()]
mock_focu = mocker.MagicMock(return_value=local_user)
mocker.patch('securedrop_client.storage.find_or_create_user', mock_focu)

update_replies(remote_replies, local_replies, mock_session, data_dir)
update_replies(remote_replies, local_replies, session, data_dir)
session.commit()

# Check the expected local reply object has been updated with values
# from the API.
assert local_reply_update.journalist_id == local_user.id
assert local_reply_update.filename == local_filename
assert local_reply_update.journalist_id == journalist.id
assert local_reply_update.size == remote_reply_update.size
assert local_reply_update.filename == remote_reply_update.filename

# Check the expected local source object has been created with values from
# the API.
assert mock_session.add.call_count == 1
new_reply = mock_session.add.call_args_list[0][0][0]
assert new_reply.uuid == remote_reply_create.uuid
assert new_reply.source_id == local_source.id
assert new_reply.journalist_id == local_user.id
new_reply = session.query(db.Reply).filter_by(uuid=remote_reply_create.uuid).one()
assert new_reply.source_id == source.id
assert new_reply.journalist_id == journalist.id
assert new_reply.size == remote_reply_create.size
assert new_reply.filename == remote_reply_create.filename
# Ensure the record for the local source that is missing from the results
# of the API is deleted.
mock_session.delete.assert_called_once_with(local_reply_delete)
# Session is committed to database.
assert mock_session.commit.call_count == 1

# Ensure the local reply that is not in the API results is deleted.
assert session.query(db.Reply).filter_by(uuid=local_reply_delete.uuid).count() == 0


def test_update_replies_cleanup_drafts(homedir, mocker, session):
Expand Down