From 535b1d9603e8749717fc60e6243d6c2f7ea170ee Mon Sep 17 00:00:00 2001 From: heartsucker Date: Tue, 7 May 2019 19:44:50 +0200 Subject: [PATCH] added timeouts to requests --- sdclientapi/__init__.py | 151 ++++++++++++++++++++++++++++++++-------- 1 file changed, 123 insertions(+), 28 deletions(-) diff --git a/sdclientapi/__init__.py b/sdclientapi/__init__.py index 6452290..6a8e7da 100644 --- a/sdclientapi/__init__.py +++ b/sdclientapi/__init__.py @@ -3,7 +3,8 @@ import os import requests from datetime import datetime -from subprocess import PIPE, Popen +from requests.exceptions import ConnectTimeout, ReadTimeout +from subprocess import PIPE, Popen, TimeoutExpired from typing import List, Tuple, Dict, Optional, Any from urllib.parse import urljoin @@ -18,9 +19,19 @@ ) DEFAULT_PROXY_VM_NAME = "sd-proxy" +DEFAULT_REQUEST_TIMEOUT = 20 # 20 seconds +DEFAULT_DOWNLOAD_TIMEOUT = 60 * 60 # 60 minutes -def json_query(proxy_vm_name: str, data: str) -> str: +class RequestTimeoutError(Exception): + """ + Error raisted if a request times out. + """ + + pass + + +def json_query(proxy_vm_name: str, data: str, timeout: Optional[int] = None) -> str: """ Takes a json based query and passes to the network proxy. Returns the JSON output from the proxy. @@ -32,9 +43,18 @@ def json_query(proxy_vm_name: str, data: str) -> str: stderr=PIPE, ) p.stdin.write(data.encode("utf-8")) - stdout, _ = p.communicate() # type: (bytes, bytes) - output = stdout.decode("utf-8") - return output.strip() + + try: + stdout, _ = p.communicate(timeout=timeout) # type: (bytes, bytes) + except TimeoutExpired: + try: + p.terminate() + except Exception: + pass + raise RequestTimeoutError + else: + output = stdout.decode("utf-8") + return output.strip() class API: @@ -45,6 +65,9 @@ class API: :param username: Journalist username :param passphrase: Journalist passphrase :param totp: Current TOTP value + :param proxy: Whether the API class should use the RPC proxy + :param default_request_timeout: Default timeout for a request (non-download) in seconds + :param default_download_timeout: Default timeout for a request (download only) in seconds :returns: An object of API class. """ @@ -55,6 +78,8 @@ def __init__( passphrase: str, totp: str, proxy: bool = False, + default_request_timeout: Optional[int] = None, + default_download_timeout: Optional[int] = None, ) -> None: """ Primary API class, this is the only thing which will make network call. @@ -68,6 +93,12 @@ def __init__( self.token_journalist_uuid = None # type: Optional[str] self.req_headers = dict() # type: Dict[str, str] self.proxy = proxy # type: bool + self.default_request_timeout = ( + default_request_timeout or DEFAULT_REQUEST_TIMEOUT + ) + self.default_download_timeout = ( + default_download_timeout or DEFAULT_DOWNLOAD_TIMEOUT + ) self.proxy_vm_name = DEFAULT_PROXY_VM_NAME config = configparser.ConfigParser() @@ -84,13 +115,14 @@ def _send_json_request( path_query: str, body: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, ) -> Tuple[Any, int, Dict[str, str]]: if self.proxy: # We are using the Qubes securedrop-proxy func = self._send_rpc_json_request else: # We are not using the Qubes securedrop-proxy func = self._send_http_json_request - return func(method, path_query, body, headers) + return func(method, path_query, body, headers, timeout) def _send_http_json_request( self, @@ -98,14 +130,21 @@ def _send_http_json_request( path_query: str, body: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, ) -> Tuple[Any, int, Dict[str, str]]: url = urljoin(self.server, path_query) kwargs = {"headers": headers} # type: Dict[str, Any] + if timeout: + kwargs["timeout"] = timeout + if method == "POST": kwargs["data"] = body - result = requests.request(method, url, **kwargs) + try: + result = requests.request(method, url, **kwargs) + except (ConnectTimeout, ReadTimeout): + raise RequestTimeoutError # Because when we download a file there is no JSON in the body if path_query.endswith("/download"): @@ -119,6 +158,7 @@ def _send_rpc_json_request( path_query: str, body: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, ) -> Tuple[Any, int, Dict[str, str]]: data = {"method": method, "path_query": path_query} # type: Dict[str, Any] @@ -128,8 +168,8 @@ def _send_rpc_json_request( if headers is not None and headers: data["headers"] = headers - data_str = json.dumps(data, sort_keys=True) - result = json.loads(json_query(self.proxy_vm_name, data_str)) + data_str = json.dumps(data) + result = json.loads(json_query(self.proxy_vm_name, data_str, timeout)) return json.loads(result["body"]), result["status"], result["headers"] def authenticate(self, totp: Optional[str] = None) -> bool: @@ -153,10 +193,11 @@ def authenticate(self, totp: Optional[str] = None) -> bool: try: token_data, status_code, headers = self._send_json_request( - method, path_query, body=body + method, path_query, body=body, timeout=self.default_request_timeout ) except json.decoder.JSONDecodeError: raise BaseError("Error in parsing JSON") + if "expiration" not in token_data: raise AuthError("Authentication error") @@ -189,7 +230,10 @@ def get_sources(self) -> List[Source]: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) except json.decoder.JSONDecodeError: raise BaseError("Error in parsing JSON") @@ -218,7 +262,10 @@ def get_source(self, source: Source) -> Source: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -256,7 +303,10 @@ def delete_source(self, source: Source) -> bool: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -298,7 +348,10 @@ def add_star(self, source: Source) -> bool: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: raise WrongUUIDError("Missing source {}".format(source.uuid)) @@ -321,7 +374,10 @@ def remove_star(self, source: Source) -> bool: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: raise WrongUUIDError("Missing source {}".format(source.uuid)) @@ -345,7 +401,10 @@ def get_submissions(self, source: Source) -> List[Submission]: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -380,7 +439,10 @@ def get_submission(self, submission: Submission) -> Submission: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -417,7 +479,10 @@ def get_all_submissions(self) -> List[Submission]: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) except json.decoder.JSONDecodeError: raise BaseError("Error in parsing JSON") @@ -451,7 +516,10 @@ def delete_submission(self, submission: Submission) -> bool: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -493,7 +561,9 @@ def download_submission( :returns: Tuple of sha256sum and path of the saved submission. """ path_query = "api/v1/sources/{}/submissions/{}/download".format( - submission.source_uuid, submission.uuid + submission.source_uuid, + submission.uuid, + timeout=self.default_download_timeout, ) method = "GET" @@ -546,7 +616,10 @@ def flag_source(self, source: Source) -> bool: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -577,7 +650,10 @@ def get_current_user(self) -> Any: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) except json.decoder.JSONDecodeError: @@ -608,7 +684,11 @@ def reply_source( try: data, status_code, headers = self._send_json_request( - method, path_query, body=json.dumps(reply), headers=self.req_headers + method, + path_query, + body=json.dumps(reply), + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 400: @@ -637,7 +717,10 @@ def get_replies_from_source(self, source: Source) -> List[Reply]: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -669,7 +752,10 @@ def get_reply_from_source(self, source: Source, reply_uuid: str) -> Reply: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -696,7 +782,10 @@ def get_all_replies(self) -> List[Reply]: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) except json.decoder.JSONDecodeError: @@ -734,7 +823,10 @@ def download_reply(self, reply: Reply, path: str = "") -> Tuple[str, str]: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: @@ -783,7 +875,10 @@ def delete_reply(self, reply: Reply) -> bool: try: data, status_code, headers = self._send_json_request( - method, path_query, headers=self.req_headers + method, + path_query, + headers=self.req_headers, + timeout=self.default_request_timeout, ) if status_code == 404: