Skip to content
This repository has been archived by the owner on Jan 7, 2024. It is now read-only.

Commit

Permalink
added timeouts to requests
Browse files Browse the repository at this point in the history
  • Loading branch information
heartsucker committed May 13, 2019
1 parent 400aeb3 commit 4227e87
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 30 deletions.
155 changes: 126 additions & 29 deletions sdclientapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import os
import requests
from datetime import datetime
from subprocess import PIPE, Popen
from requests.exceptions import ConnectTimeout, ReadTimeout
from subprocess import PIPE, Popen, TimeoutExpired
from typing import List, Tuple, Dict, Optional, Any
from urllib.parse import urljoin

Expand All @@ -18,9 +19,20 @@
)

DEFAULT_PROXY_VM_NAME = "sd-proxy"
DEFAULT_REQUEST_TIMEOUT = 20 # 20 seconds
DEFAULT_DOWNLOAD_TIMEOUT = 60 * 60 # 60 minutes


def json_query(proxy_vm_name: str, data: str) -> str:
class RequestTimeoutError(Exception):
"""
Error raisted if a request times out.
"""

def __init__(self) -> None:
super().__init__("The request timed out.")


def json_query(proxy_vm_name: str, data: str, timeout: Optional[int] = None) -> str:
"""
Takes a json based query and passes to the network proxy.
Returns the JSON output from the proxy.
Expand All @@ -32,9 +44,18 @@ def json_query(proxy_vm_name: str, data: str) -> str:
stderr=PIPE,
)
p.stdin.write(data.encode("utf-8"))
stdout, _ = p.communicate() # type: (bytes, bytes)
output = stdout.decode("utf-8")
return output.strip()

try:
stdout, _ = p.communicate(timeout=timeout) # type: (bytes, bytes)
except TimeoutExpired:
try:
p.terminate()
except Exception:
pass
raise RequestTimeoutError
else:
output = stdout.decode("utf-8")
return output.strip()


class API:
Expand All @@ -45,6 +66,9 @@ class API:
:param username: Journalist username
:param passphrase: Journalist passphrase
:param totp: Current TOTP value
:param proxy: Whether the API class should use the RPC proxy
:param default_request_timeout: Default timeout for a request (non-download) in seconds
:param default_download_timeout: Default timeout for a request (download only) in seconds
:returns: An object of API class.
"""

Expand All @@ -55,6 +79,8 @@ def __init__(
passphrase: str,
totp: str,
proxy: bool = False,
default_request_timeout: Optional[int] = None,
default_download_timeout: Optional[int] = None,
) -> None:
"""
Primary API class, this is the only thing which will make network call.
Expand All @@ -68,6 +94,12 @@ def __init__(
self.token_journalist_uuid = None # type: Optional[str]
self.req_headers = dict() # type: Dict[str, str]
self.proxy = proxy # type: bool
self.default_request_timeout = (
default_request_timeout or DEFAULT_REQUEST_TIMEOUT
)
self.default_download_timeout = (
default_download_timeout or DEFAULT_DOWNLOAD_TIMEOUT
)

self.proxy_vm_name = DEFAULT_PROXY_VM_NAME
config = configparser.ConfigParser()
Expand All @@ -84,28 +116,36 @@ def _send_json_request(
path_query: str,
body: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Tuple[Any, int, Dict[str, str]]:
if self.proxy: # We are using the Qubes securedrop-proxy
func = self._send_rpc_json_request
else: # We are not using the Qubes securedrop-proxy
func = self._send_http_json_request

return func(method, path_query, body, headers)
return func(method, path_query, body, headers, timeout)

def _send_http_json_request(
self,
method: str,
path_query: str,
body: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Tuple[Any, int, Dict[str, str]]:
url = urljoin(self.server, path_query)
kwargs = {"headers": headers} # type: Dict[str, Any]

if timeout:
kwargs["timeout"] = timeout

if method == "POST":
kwargs["data"] = body

result = requests.request(method, url, **kwargs)
try:
result = requests.request(method, url, **kwargs)
except (ConnectTimeout, ReadTimeout):
raise RequestTimeoutError

# Because when we download a file there is no JSON in the body
if path_query.endswith("/download"):
Expand All @@ -119,6 +159,7 @@ def _send_rpc_json_request(
path_query: str,
body: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Tuple[Any, int, Dict[str, str]]:
data = {"method": method, "path_query": path_query} # type: Dict[str, Any]

Expand All @@ -128,8 +169,8 @@ def _send_rpc_json_request(
if headers is not None and headers:
data["headers"] = headers

data_str = json.dumps(data, sort_keys=True)
result = json.loads(json_query(self.proxy_vm_name, data_str))
data_str = json.dumps(data)
result = json.loads(json_query(self.proxy_vm_name, data_str, timeout))
return json.loads(result["body"]), result["status"], result["headers"]

def authenticate(self, totp: Optional[str] = None) -> bool:
Expand All @@ -153,10 +194,11 @@ def authenticate(self, totp: Optional[str] = None) -> bool:

try:
token_data, status_code, headers = self._send_json_request(
method, path_query, body=body
method, path_query, body=body, timeout=self.default_request_timeout
)
except json.decoder.JSONDecodeError:
raise BaseError("Error in parsing JSON")

if "expiration" not in token_data:
raise AuthError("Authentication error")

Expand Down Expand Up @@ -189,7 +231,10 @@ def get_sources(self) -> List[Source]:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)
except json.decoder.JSONDecodeError:
raise BaseError("Error in parsing JSON")
Expand Down Expand Up @@ -218,7 +263,10 @@ def get_source(self, source: Source) -> Source:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -256,7 +304,10 @@ def delete_source(self, source: Source) -> bool:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -298,7 +349,10 @@ def add_star(self, source: Source) -> bool:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)
if status_code == 404:
raise WrongUUIDError("Missing source {}".format(source.uuid))
Expand All @@ -321,7 +375,10 @@ def remove_star(self, source: Source) -> bool:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)
if status_code == 404:
raise WrongUUIDError("Missing source {}".format(source.uuid))
Expand All @@ -345,7 +402,10 @@ def get_submissions(self, source: Source) -> List[Submission]:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -380,7 +440,10 @@ def get_submission(self, submission: Submission) -> Submission:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -417,7 +480,10 @@ def get_all_submissions(self) -> List[Submission]:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)
except json.decoder.JSONDecodeError:
raise BaseError("Error in parsing JSON")
Expand Down Expand Up @@ -451,7 +517,10 @@ def delete_submission(self, submission: Submission) -> bool:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -481,7 +550,7 @@ def delete_submission_from_string(self, uuid: str, source_uuid: str) -> bool:
return self.delete_submission(s)

def download_submission(
self, submission: Submission, path: str = ""
self, submission: Submission, path: str = "", timeout: Optional[int] = None
) -> Tuple[str, str]:
"""
Returns a tuple of sha256sum and file path for a given Submission object. This method
Expand All @@ -503,7 +572,10 @@ def download_submission(

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=timeout or self.default_download_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -546,7 +618,10 @@ def flag_source(self, source: Source) -> bool:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -577,7 +652,10 @@ def get_current_user(self) -> Any:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

except json.decoder.JSONDecodeError:
Expand Down Expand Up @@ -608,7 +686,11 @@ def reply_source(

try:
data, status_code, headers = self._send_json_request(
method, path_query, body=json.dumps(reply), headers=self.req_headers
method,
path_query,
body=json.dumps(reply),
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 400:
Expand Down Expand Up @@ -637,7 +719,10 @@ def get_replies_from_source(self, source: Source) -> List[Reply]:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -669,7 +754,10 @@ def get_reply_from_source(self, source: Source, reply_uuid: str) -> Reply:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand All @@ -696,7 +784,10 @@ def get_all_replies(self) -> List[Reply]:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

except json.decoder.JSONDecodeError:
Expand Down Expand Up @@ -734,7 +825,10 @@ def download_reply(self, reply: Reply, path: str = "") -> Tuple[str, str]:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -783,7 +877,10 @@ def delete_reply(self, reply: Reply) -> bool:

try:
data, status_code, headers = self._send_json_request(
method, path_query, headers=self.req_headers
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down
2 changes: 1 addition & 1 deletion sdclientapi/sdlocalobjects.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

if typing.TYPE_CHECKING:
from typing import Dict
from typing import Dict # noqa: F401


class BaseError(Exception):
Expand Down

0 comments on commit 4227e87

Please sign in to comment.