Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up threading #77

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 121 additions & 38 deletions securedrop_client/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from securedrop_client.utils import check_dir_permissions
from PyQt5.QtCore import QObject, QThread, pyqtSignal, QTimer


logger = logging.getLogger(__name__)


Expand All @@ -36,7 +35,6 @@ class APICallRunner(QObject):
"""

call_finished = pyqtSignal(bool) # Indicates there is a result.
timeout = pyqtSignal() # Indicates there was a timeout.

def __init__(self, api_call, *args, **kwargs):
"""
Expand All @@ -48,33 +46,30 @@ def __init__(self, api_call, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.result = None
self.i_timed_out = False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flag telling the API thread that it's been timed out, and when it finally gets unblocked it should Do Nothing.


def call_api(self):
"""
Call the API. Emit a boolean signal to indicate the outcome of the
call. Timeout signal emitted after 5 seconds. Any return value or
exception raised is stored in self.result.
call. Any return value or exception raised is stored in self.result.
"""
self.timer = QTimer()
self.timer.timeout.connect(lambda: self.timeout.emit())
self.timer.setSingleShot(True)
self.timer.start(5000)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the timer stuff moved to the Client object.


# this blocks
try:
logger.info('Calling API with "{}" method'.format(
self.api_call.__name__))
self.result = self.api_call(*self.args, **self.kwargs)
result_flag = bool(self.result)
except Exception as ex:
logger.error(ex)
self.result = ex
result_flag = False
self.call_finished.emit(result_flag)

def on_cancel_timeout(self):
"""
Handles a signal to indicate the timer should stop.
"""
self.timer.stop()
# by the time we end up here, who knows how long it's taken
# we may not want to emit this, if there's nothing to catch it
if self.i_timed_out is False:
self.call_finished.emit(result_flag)
else:
logger.info("Thread returned from API call, "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note this is basically just for informational purposes for devs right now so this one line I no covered

"but it had timed out.") # pragma: no cover


class Client(QObject):
Expand All @@ -83,7 +78,7 @@ class Client(QObject):
application, this is the controller.
"""

finish_api_call = pyqtSignal() # Acknowledges reciept of an API call.
timeout_api_call = pyqtSignal() # Indicates there was a timeout.

def __init__(self, hostname, gui, session, home: str) -> None:
"""
Expand All @@ -103,6 +98,7 @@ def __init__(self, hostname, gui, session, home: str) -> None:
self.sync_flag = os.path.join(home, 'sync_flag')
self.home = home # The "home" directory for client files.
self.data_dir = os.path.join(self.home, 'data') # File data.
self.timer = None # call timeout timer

def setup(self):
"""
Expand Down Expand Up @@ -133,41 +129,70 @@ def call_api(self, function, callback, timeout, *args, current_object=None,
timeout signal. Any further arguments are passed to the function to be
called.
"""

if not self.api_thread:
self.timer = QTimer()
self.timer.timeout.connect(lambda: self.timeout_api_call.emit())
self.timer.setSingleShot(True)
self.timer.start(20000)

self.api_thread = QThread(self.gui)
self.api_runner = APICallRunner(function, *args, **kwargs)
self.api_runner.moveToThread(self.api_thread)
self.api_runner.current_object = current_object
self.api_runner.call_finished.connect(callback)
self.api_runner.timeout.connect(timeout)
self.finish_api_call.connect(self.api_runner.on_cancel_timeout)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is handled in our callback wrapper now.


# handle successful call: copy response data, reset the
# client, give the user-provided callback the response
# data
self.api_runner.call_finished.connect(
lambda r: self.successful_api_call(r, callback))

# we've started a timer. when that hits zero, call our
# timeout function
self.timeout_api_call.connect(
lambda: self.timeout_cleanup(timeout))

# when the thread starts, we want to run `call_api` on `api_runner`
self.api_thread.started.connect(self.api_runner.call_api)
self.api_thread.finished.connect(self.call_reset)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also handled in our callback wrapper.


self.api_thread.start()

else:
logger.info("Concurrent API requests are not implemented yet and "
"an API request is already running.")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the log message implies, we should consider how we actually want to handle "concurrent" API requests...

def call_reset(self):
"""
Clean up this object's state after an API call.
"""
if self.api_thread:
self.finish_api_call.emit()
self.timeout_api_call.disconnect()
self.api_runner = None
self.api_thread = None
self.timer = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call_reset is called by our callback wrappers. We take care to disconnect the timeout handler callbacks: if we don't, then we'll just add more and more callbacks over time, and when one call finally times out, all collected timeout handlers will get triggered (even for requests that have long-since completed).


def login(self, username, password, totp):
"""
Given a username, password and time based one-time-passcode (TOTP),
create a new instance representing the SecureDrop api and authenticate.
"""

self.api = sdclientapi.API(self.hostname, username, password, totp)

self.call_api(self.api.authenticate, self.on_authenticate,
self.on_login_timeout)

def on_authenticate(self, result):
def on_cancel_timeout(self):
"""
Handles a signal to indicate the timer should stop.
"""
self.timer.stop()

def on_authenticate(self, result, result_data):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note the new call signature, adding result_data (unused in this method, as it turns out)

"""
Handles the result of an authentication call against the API.
"""
self.call_reset()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now handled in our callback wrapper


if result:
# It worked! Sync with the API and update the UI.
self.gui.hide_login()
Expand All @@ -182,15 +207,65 @@ def on_authenticate(self, result):
error = _('There was a problem logging in. Please try again.')
self.gui.show_login_error(error=error)

def successful_api_call(self, r, user_callback):
logger.info("Successful API call. Cleaning up and running callback.")

self.timer.stop()
result_data = self.api_runner.result

# The callback may or may not have an associated current_object
if self.api_runner.current_object:
current_object = self.api_runner.current_object
else:
current_object = None

self.call_reset()

if current_object:
user_callback(r, result_data, current_object=current_object)
else:
user_callback(r, result_data)

def timeout_cleanup(self, user_callback):
logger.info("API call timed out. Cleaning up and running "
"timeout callback.")

if self.api_thread:
self.api_runner.i_timed_out = True

if self.api_runner.current_object:
current_object = self.api_runner.current_object
else:
current_object = None

self.call_reset()

if current_object:
user_callback(current_object=current_object)
else:
user_callback()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are our callback wrappers, which reset the state of the Client before calling user callbacks.

def on_login_timeout(self):
"""
Reset the form and indicate the error.
"""
self.call_reset()

self.api = None
error = _('The connection to SecureDrop timed out. Please try again.')
self.gui.show_login_error(error=error)

def on_sync_timeout(self):
"""
Indicate that a sync failed.

TODO: We don't really want to alert in the error bar _every time_
this happens. Instead, we should do something like: alert if there
have been many timeouts in a row.
"""

error = _('The connection to SecureDrop timed out. Please try again.')
self.gui.update_error_status(error=error)

def on_action_requiring_login(self):
"""
Indicate that a user needs to login to perform the specified action.
Expand All @@ -217,9 +292,15 @@ def sync_api(self):
"""
Grab data from the remote SecureDrop API in a non-blocking manner.
"""
logger.debug("In sync_api on thread {}".format(
self.thread().currentThreadId()))

if self.authenticated():
logger.debug("You are authenticated, going to make your call")
self.call_api(storage.get_remote_data, self.on_synced,
self.on_login_timeout, self.api)
self.on_sync_timeout, self.api)
logger.debug("In sync_api, after call to call_api, on "
"thread {}".format(self.thread().currentThreadId()))

def last_sync(self):
"""
Expand All @@ -231,17 +312,19 @@ def last_sync(self):
except Exception:
return None

def on_synced(self, result):
def on_synced(self, result, result_data):
"""
Called when syncronisation of data via the API is complete.
"""
if result and isinstance(self.api_runner.result, tuple):

if result and isinstance(result_data, tuple):
remote_sources, remote_submissions, remote_replies = \
self.api_runner.result
self.call_reset()
result_data

storage.update_local_storage(self.session, remote_sources,
remote_submissions,
remote_replies)

# Set last sync flag.
with open(self.sync_flag, 'w') as f:
f.write(arrow.now().format())
Expand All @@ -251,6 +334,7 @@ def on_synced(self, result):
# How to handle a failure? Exceptions are already logged. Perhaps
# a message in the UI?
pass

self.update_sources()

def update_sync(self):
Expand All @@ -267,14 +351,14 @@ def update_sources(self):
self.gui.show_sources(sources)
self.update_sync()

def on_update_star_complete(self, result):
def on_update_star_complete(self, result, result_data):
"""
After we star or unstar a source, we should sync the API
such that the local database is updated.

TODO: Improve the push to server sync logic.
"""
self.call_reset()

if result:
self.sync_api() # Syncing the API also updates the source list UI
self.gui.update_error_status("")
Expand Down Expand Up @@ -340,16 +424,15 @@ def on_file_click(self, source_db_object, message):
self.on_download_timeout, sdk_object, self.data_dir,
current_object=message)

def on_file_download(self, result):
def on_file_download(self, result, result_data, current_object):
"""
Called when a file has downloaded. Cause a refresh to the conversation
view to display the contents of the new file.
"""
sha256sum, filename = self.api_runner.result
file_uuid = self.api_runner.current_object.uuid
server_filename = self.api_runner.current_object.filename
self.call_reset()
file_uuid = current_object.uuid
server_filename = current_object.filename
if result:
sha256sum, filename = result_data
# The filename contains the location where the file has been
# stored. On non-Qubes OSes, this will be the data directory.
# On Qubes OS, this will a ~/QubesIncoming directory. In case
Expand All @@ -365,7 +448,7 @@ def on_file_download(self, result):
# Update the UI in some way to indicate a failure state.
self.set_status("Failed to download file, please try again.")

def on_download_timeout(self):
def on_download_timeout(self, current_object):
"""
Called when downloading a file has timed out.
"""
Expand Down
Loading