Skip to content
This repository has been archived by the owner on Aug 2, 2023. It is now read-only.

Commit

Permalink
Separate binding from connecting.
Browse files Browse the repository at this point in the history
  • Loading branch information
ericsnowcurrently committed Feb 20, 2018
1 parent d275a9d commit 8062503
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 83 deletions.
59 changes: 35 additions & 24 deletions tests/helpers/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,33 @@ def parse_each(self, messages):
class Started(object):
"""A simple wrapper around a started message protocol daemon."""

def __init__(self, fake):
def __init__(self, fake, address, starting=None):
self.fake = fake
self.address = address
self._starting = starting

def __enter__(self):
self.wait_until_connected()
return self

def __exit__(self, *args):
self.close()

def wait_until_connected(self, timeout=None):
starting = self._starting
if starting is None:
return
starting.join(timeout=timeout)
if starting.is_alive():
raise RuntimeError('timed out')
self._starting = None

def send_message(self, msg):
return self.fake.send_response(msg)
self.wait_until_connected()
return self.fake.send_message(msg)

def close(self):
self.wait_until_connected()
self.fake.close()


Expand All @@ -69,8 +83,8 @@ def validate_message(cls, msg):
"""Ensure the message is legitimate."""
# By default check nothing.

def __init__(self, connect, protocol, handler):
self._connect = connect
def __init__(self, bind, protocol, handler):
self._bind = bind
self._protocol = protocol

self._closed = False
Expand All @@ -81,10 +95,8 @@ def __init__(self, connect, protocol, handler):
self._default_handler = handler

# These are set when we start.
self._host = None
self._port = None
self._address = None
self._sock = None
self._server = None
self._listener = None

@property
Expand All @@ -103,18 +115,17 @@ def failures(self):
"""All send/recv failures thus far."""
return list(self._failures)

def start(self, host, port):
def start(self, address):
"""Start the fake daemon.
This calls the earlier provided connect() function.
This calls the earlier provided bind() function.
A listener loop is started in another thread to handle incoming
messages from the socket.
"""
self._host = host or None
self._port = port
self._start()
return self.STARTED(self)
self._address = address
addr, starting = self._start(address)
return self.STARTED(self, addr, starting)

def send_message(self, msg):
"""Serialize msg to the line format and send it to the socket."""
Expand Down Expand Up @@ -150,15 +161,17 @@ def reset(self, force=False):

# internal methods

def _start(self, host=None):
self._sock, self._server = self._connect(
host or self._host,
self._port,
)
def _start(self, address):
connect, addr = self._bind(address)

# TODO: make it a daemon thread?
self._listener = threading.Thread(target=self._listen)
self._listener.start()
def run():
self._sock = connect()
# TODO: make it a daemon thread?
self._listener = threading.Thread(target=self._listen)
self._listener.start()
t = threading.Thread(target=run)
t.start()
return addr, t

def _listen(self):
try:
Expand Down Expand Up @@ -210,6 +223,7 @@ def _send_message(self, msg):
try:
self._send(raw)
except Exception as exc:
raise
failure = StreamFailure('send', msg, exc)
self._failures.append(failure)

Expand All @@ -222,9 +236,6 @@ def _close(self):
if self._sock is not None:
socket.close(self._sock)
self._sock = None
if self._server is not None:
socket.close(self._server)
self._server = None
if self._listener is not None:
self._listener.join(timeout=1)
# TODO: the listener isn't stopping!
Expand Down
21 changes: 13 additions & 8 deletions tests/helpers/pydevd/_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
CMD_VERSION,
)

from ptvsd.wrapper import start_server, start_client
import ptvsd.wrapper as _ptvsd
from ._pydevd import parse_message, encode_message, iter_messages, Message
from tests.helpers import protocol
from tests.helpers import protocol, socket


PROTOCOL = protocol.MessageProtocol(
Expand All @@ -14,19 +14,24 @@
)


def _connect(host, port):
if host is None:
return start_server(port), None
else:
return start_client(host, port), None
def _bind(address):
connect, remote = socket.bind(address)

def connect(_connect=connect):
client, server = _connect()
pydevd, _, _ = _ptvsd._start(client, server)
return socket.Connection(pydevd, server)
return connect, remote


class Started(protocol.Started):

def send_response(self, msg):
self.wait_until_connected()
return self.fake.send_response(msg)

def send_event(self, msg):
self.wait_until_connected()
return self.fake.send_event(msg)


Expand Down Expand Up @@ -92,7 +97,7 @@ def _get_response(cls, req):

def __init__(self, handler=None):
super(FakePyDevd, self).__init__(
_connect,
_bind,
PROTOCOL,
(lambda msg, send: self.handle_request(msg, send, handler)),
)
Expand Down
111 changes: 90 additions & 21 deletions tests/helpers/socket.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,107 @@
from __future__ import absolute_import

from collections import namedtuple
import socket

import ptvsd.wrapper as _ptvsd

def connect(host, port):
"""Return (client, server) after connecting.

If host is None then it's a server, so it will wait for a connection
on localhost. Otherwise it will connect to the remote host.
def create_server(address):
"""Return a server socket after binding."""
host, port = address
return _ptvsd._create_server(port)


def create_client():
"""Return a new (unconnected) client socket."""
return _ptvsd._create_client()


def connect(sock, address):
"""Return a client socket after connecting.
If address is None then it's a server, so it will wait for a
connection. Otherwise it will connect to the remote host.
"""
return _connect(sock, address)


def bind(address):
"""Return (connect, remote addr) for the given address.
"connect" is a function with no args that returns (client, server),
which are sockets. If the host is None then a server socket will
be created bound to localhost, and that server socket will be
returned from connect(). Otherwise a client socket is connected to
the remote address and None is returned from connect() for the
server.
"""
sock = socket.socket(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
)
sock.setsockopt(
socket.SOL_SOCKET,
socket.SO_REUSEADDR,
1,
)
host, _ = address
if host is None:
addr = ('127.0.0.1', port)
sock = create_server(address)
server = sock
server.bind(addr)
server.listen(1)
sock, _ = server.accept()
connect_to = None
remote = sock.getsockname()
else:
addr = (host, port)
sock.connect(addr)
sock = create_client()
server = None
return sock, server
connect_to = address
remote = address

def connect():
client = _connect(sock, connect_to)
return client, server
return connect, remote


def close(sock):
"""Shutdown and close the socket."""
sock.shutdown(socket.SHUT_RDWR)
sock.close()


class Connection(namedtuple('Connection', 'client server')):
"""A wrapper around a client socket.
If a server socket is provided then it will be closed when the
client is closed.
"""

def __new__(cls, client, server=None):
self = super(Connection, cls).__new__(
cls,
client,
server,
)
return self

def send(self, *args, **kwargs):
return self.client.send(*args, **kwargs)

def recv(self, *args, **kwargs):
return self.client.recv(*args, **kwargs)

def makefile(self, *args, **kwargs):
return self.client.makefile(*args, **kwargs)

def shutdown(self, *args, **kwargs):
if self.server is not None:
self.server.shutdown(*args, **kwargs)
self.client.shutdown(*args, **kwargs)

def close(self):
if self.server is not None:
self.server.close()
self.client.close()


########################
# internal functions

def _connect(sock, address):
if address is None:
client, _ = sock.accept()
else:
sock.connect(address)
client = sock
return client
52 changes: 30 additions & 22 deletions tests/helpers/vsc/_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@
)


def _bind(address):
connect, remote = socket.bind(address)

def connect(_connect=connect):
client, server = _connect()
return socket.Connection(client, server)
return connect, remote


class Started(protocol.Started):

def send_request(self, msg):
self.wait_until_connected()
return self.fake.send_request(msg)


Expand Down Expand Up @@ -52,19 +62,23 @@ class FakeVSC(protocol.Daemon):
PROTOCOL = PROTOCOL

def __init__(self, start_adapter, handler=None):
super(FakeVSC, self).__init__(socket.connect, PROTOCOL, handler)

def start_adapter(host, port, _start_adapter=start_adapter):
self._adapter = _start_adapter(host, port)

super(FakeVSC, self).__init__(
_bind,
PROTOCOL,
handler,
)

def start_adapter(address, start=start_adapter):
self._adapter = start(address)
return self._adapter
self._start_adapter = start_adapter
self._adapter = None

def start(self, host, port):
def start(self, address):
"""Start the fake and the adapter."""
if self._adapter is not None:
raise RuntimeError('already started')
return super(FakeVSC, self).start(host, port)
return super(FakeVSC, self).start(address)

def send_request(self, req):
"""Send the given Request object."""
Expand Down Expand Up @@ -94,26 +108,20 @@ def match(msg):

# internal methods

def _start(self, host=None):
start_adapter = (lambda: self._start_adapter(self._host, self._port))
if not self._host:
def _start(self, address):
host, port = address
if host is None:
# The adapter is the server so start it first.
t = threading.Thread(target=start_adapter)
t.start()
super(FakeVSC, self)._start('127.0.0.1')
t.join(timeout=1)
if t.is_alive():
raise RuntimeError('timed out')
adapter = self._start_adapter((None, port))
return super(FakeVSC, self)._start(adapter.address)
else:
# The adapter is the client so start it last.
# TODO: For now don't use this.
raise NotImplementedError
t = threading.Thread(target=super(FakeVSC, self)._start)
t.start()
start_adapter()
t.join(timeout=1)
if t.is_alive():
raise RuntimeError('timed out')
addr, starting = super(FakeVSC, self)._start(address)
self._start_adapter(addr)
# TODO Wait for adapter to be ready?
return addr, starting

def _close(self):
if self._adapter is not None:
Expand Down
Loading

0 comments on commit 8062503

Please sign in to comment.