Skip to content

Commit

Permalink
Merge pull request #21 from scivisum/task/RD-36216_why_ws_connections…
Browse files Browse the repository at this point in the history
…_closed

Task/rd 36216 why ws connections closed
  • Loading branch information
Erustus Agutu authored Sep 20, 2019
2 parents 7b7761f + 7548b8c commit 421cc66
Show file tree
Hide file tree
Showing 8 changed files with 433 additions and 244 deletions.
77 changes: 16 additions & 61 deletions browserdebuggertools/chrome/interface.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import contextlib
import time
import logging
from base64 import b64decode, b64encode

from browserdebuggertools.sockethandler import SocketHandler
from browserdebuggertools.exceptions import (
DevToolsTimeoutException, ResultNotFoundError, DomainNotFoundError
)


logging.basicConfig(format='%(levelname)s:%(message)s')

Expand All @@ -28,45 +25,14 @@ def __init__(self, port, timeout=30, domains=None):
:param port: remote-debugging-port to connect.
:param timeout: Timeout between executing a command and receiving a result.
:param domains: List of domains to be enabled. By default Page, Network and Runtime are
automatically enabled.
:param domains: Dictionary of dictionaries where the Key is the domain string and the Value
is a dictionary of the arguments passed with the domain upon enabling.
"""
self.timeout = timeout
self._socket_handler = SocketHandler(port)

if domains:
for domain in domains:
self.enable_domain(domain)
self._socket_handler = SocketHandler(port, timeout, domains=domains)

def quit(self):
self._socket_handler.close()

def wait_for_result(self, result_id):
""" Waits for a result to complete within the timeout duration then returns it.
Raises a DevToolsTimeoutException if it cannot find the result.
:param result_id: The result id.
:return: The result.
"""
start = time.time()
while not self.timeout or (time.time() - start) < self.timeout:
try:
return self._socket_handler.find_result(result_id)
except ResultNotFoundError:
time.sleep(0.5)
raise DevToolsTimeoutException(
"Reached timeout limit of {}, waiting for a response message".format(self.timeout)
)

def get_result(self, result_id):
""" Gets the result for a given id, if it has finished executing
Raises a ResultNotFoundError if it cannot find the result.
:param result_id: The result id.
:return: The result.
"""
return self._socket_handler.find_result(result_id)

def get_events(self, domain, clear=False):
""" Retrieves all events for a given domain
:param domain: The domain to get the events for.
Expand All @@ -75,7 +41,7 @@ def get_events(self, domain, clear=False):
"""
return self._socket_handler.get_events(domain, clear)

def execute(self, domain, method, args=None):
def execute(self, domain, method, params=None):
""" Executes a command and returns the result.
Usage example:
Expand All @@ -86,44 +52,33 @@ def execute(self, domain, method, args=None):
:param domain: Chrome DevTools Protocol Domain
:param method: Domain specific method.
:param args: Parameters to be executed
:param params: Parameters to be executed
:return: The result of the command
"""
result_id = self._socket_handler.execute("{}.{}".format(domain, method), args)

return self.wait_for_result(result_id)
return self._socket_handler.execute(domain, method, params=params)

def enable_domain(self, domain):
def enable_domain(self, domain, params=None):
""" Enables notifications for the given domain.
"""
self._socket_handler.add_domain(domain)
result = self.execute(domain, "enable")
if "error" in result:
self._socket_handler.remove_domain(domain)
raise DomainNotFoundError("Domain \"{}\" not found.".format(domain))

logging.info("\"{}\" domain has been enabled".format(domain))
self._socket_handler.enable_domain(domain, parameters=params)

def disable_domain(self, domain):
""" Disables further notifications from the given domain.
""" Disables further notifications from the given domain. Also clears any events cached for
that domain, it is recommended that you get events for the domain before disabling it.
"""
self._socket_handler.remove_domain(domain)
result = self.execute(domain, "disable")
if "error" in result:
logging.warn("Domain \"{}\" doesn't exist".format(domain))
else:
logging.info("Domain {} has been disabled".format(domain))
self._socket_handler.disable_domain(domain)

@contextlib.contextmanager
def set_timeout(self, value):
""" Switches the timeout to the given value.
"""
_timeout = self.timeout
self.timeout = value
_timeout = self._socket_handler.timeout
self._socket_handler.timeout = value
try:
yield
finally:
self.timeout = _timeout
self._socket_handler.timeout = _timeout

def navigate(self, url):
""" Navigates to the given url asynchronously
Expand Down
205 changes: 164 additions & 41 deletions browserdebuggertools/sockethandler.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,106 @@
import json
import logging
import socket
import time
from datetime import datetime

import requests
import websocket

from browserdebuggertools.exceptions import ResultNotFoundError, TabNotFoundError, \
DomainNotEnabledError
from browserdebuggertools.exceptions import (
ResultNotFoundError, TabNotFoundError,
DomainNotEnabledError, DevToolsTimeoutException, DomainNotFoundError
)

logging.basicConfig(format='%(levelname)s:%(message)s')


def open_connection_if_closed(socket_handler_method):

def retry_if_exception(socket_handler_instance, *args, **kwargs):

try:
return socket_handler_method(socket_handler_instance, *args, **kwargs)

except websocket.WebSocketConnectionClosedException:

socket_handler_instance.increment_connection_closed_count()
retry_if_exception(socket_handler_instance, *args, **kwargs)

return retry_if_exception


class SocketHandler(object):

CONN_TIMEOUT = 15 # Connection timeout
MAX_CONNECTION_RETRIES = 3
RETRY_COUNT_TIMEOUT = 300 # Seconds
CONN_TIMEOUT = 15 # Connection timeout seconds

def __init__(self, port):
websocket_url = self._get_websocket_url(port)
self.websocket = websocket.create_connection(websocket_url, timeout=self.CONN_TIMEOUT)
self.websocket.settimeout(0) # Don"t wait for new messages
def __init__(self, port, timeout, domains=None):

self.timeout = timeout

if not domains:
domains = {}

self._domains = domains
self._events = dict([(k, []) for k in self._domains])
self._results = {}

self._next_result_id = 0
self.domains = set()
self.results = {}
self.events = {}
self._connection_last_closed = None
self._connection_closed_count = 0

self._websocket_url = self._get_websocket_url(port)
self._websocket = self._setup_websocket()

def __del__(self):
try:
self.close()
except:
pass

def _setup_websocket(self):

self._websocket = websocket.create_connection(
self._websocket_url, timeout=self.CONN_TIMEOUT
)
self._websocket.settimeout(0) # Don"t wait for new messages

for domain, params in self._domains.items():
self.enable_domain(domain, params)

return self._websocket

def increment_connection_closed_count(self):

now = datetime.now()

if (
self._connection_last_closed and
(now - self._connection_last_closed).seconds > self.RETRY_COUNT_TIMEOUT
):
self._connection_closed_count = 0

self._connection_last_closed = now
self._connection_closed_count += 1

if self._connection_closed_count > self.MAX_CONNECTION_RETRIES:
raise Exception("Websocket connection found closed too many times")

self._setup_websocket()

@open_connection_if_closed
def _send(self, data):
data['id'] = self._next_result_id
self._websocket.send(json.dumps(data, sort_keys=True))

@open_connection_if_closed
def _recv(self):
message = self._websocket.recv()
if message:
message = json.loads(message)
return message

def _get_websocket_url(self, port):
targets = requests.get(
Expand All @@ -36,67 +113,113 @@ def _get_websocket_url(self, port):
return tabs[0]["webSocketDebuggerUrl"]

def close(self):
self.websocket.close()
if hasattr(self, "_websocket"):
self._websocket.close()

def _append(self, message):

if "result" in message:
self.results[message["id"]] = message.get("result")
self._results[message["id"]] = message.get("result")
elif "error" in message:
result_id = message.pop("id")
self.results[result_id] = message
self._results[result_id] = message
elif "method" in message:
domain, event = message["method"].split(".")
self.events[domain].append(message)
self._events[domain].append(message)
else:
logging.warning("Unrecognised message: {}".format(message))

def flush_messages(self):
def _flush_messages(self):
""" Will only return once all the messages have been retrieved.
and will hold the thread until so.
"""
try:
message = self.websocket.recv()
message = self._recv()
while message:
message = json.loads(message)
self._append(message)
message = self.websocket.recv()
message = self._recv()
except socket.error:
return

def find_result(self, result_id):
if result_id not in self.results:
self.flush_messages()
def _find_next_result(self):
if self._next_result_id not in self._results:
self._flush_messages()

if result_id not in self.results:
raise ResultNotFoundError("Result not found for id: {} .".format(result_id))
if self._next_result_id not in self._results:
raise ResultNotFoundError("Result not found for id: {} .".format(self._next_result_id))

return self.results.pop(result_id)
return self._results.pop(self._next_result_id)

def execute(self, method, params):
self._next_result_id += 1
self.websocket.send(json.dumps({
"id": self._next_result_id, "method": method, "params": params if params else {}
}, sort_keys=True))
return self._next_result_id
def execute(self, domainName, methodName, params=None):

def add_domain(self, domain):
if domain not in self.domains:
self.domains.add(domain)
self.events[domain] = []
if params is None:
params = {}

def remove_domain(self, domain):
if domain in self.domains:
self.domains.remove(domain)
self._next_result_id += 1
method = "{}.{}".format(domainName, methodName)
self._send({
"method": method, "params": params
})
return self._wait_for_result()

def _add_domain(self, domain, params):
if domain not in self._domains:
self._domains[domain] = params
self._events[domain] = []

def _remove_domain(self, domain):
if domain in self._domains:
del self._domains[domain]
del self._events[domain]

def get_events(self, domain, clear=False):
if domain not in self.domains:
if domain not in self._domains:
raise DomainNotEnabledError(
'The domain "%s" is not enabled, try enabling it via the interface.' % domain
)

self.flush_messages()
events = self.events[domain][:]
self._flush_messages()
events = self._events[domain][:]
if clear:
self.events[domain] = []
self._events[domain] = []

return events

def _wait_for_result(self):
""" Waits for a result to complete within the timeout duration then returns it.
Raises a DevToolsTimeoutException if it cannot find the result.
:return: The result.
"""
start = time.time()
while not self.timeout or (time.time() - start) < self.timeout:
try:
return self._find_next_result()
except ResultNotFoundError:
time.sleep(0.5)
raise DevToolsTimeoutException(
"Reached timeout limit of {}, waiting for a response message".format(self.timeout)
)

def enable_domain(self, domainName, parameters=None):

if not parameters:
parameters = {}

self._add_domain(domainName, parameters)
result = self.execute(domainName, "enable", parameters)
if "error" in result:
self._remove_domain(domainName)
raise DomainNotFoundError("Domain \"{}\" not found.".format(domainName))

logging.info("\"{}\" domain has been enabled".format(domainName))

def disable_domain(self, domainName):
""" Disables further notifications from the given domain.
"""
self._remove_domain(domainName)
result = self.execute(domainName, "disable", {})
if "error" in result:
logging.warn("Domain \"{}\" doesn't exist".format(domainName))
else:
logging.info("Domain {} has been disabled".format(domainName))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name="browserdebuggertools",
version="4.0.0",
version="5.0.0",
packages=PACKAGES,
install_requires=requires,
license="GNU General Public License v3",
Expand Down
Loading

0 comments on commit 421cc66

Please sign in to comment.