diff --git a/ipykernel/debugger.py b/ipykernel/debugger.py index 900c10bf0..8acb65cbf 100644 --- a/ipykernel/debugger.py +++ b/ipykernel/debugger.py @@ -10,6 +10,7 @@ from .compiler import (get_file_name, get_tmp_directory, get_tmp_hash_seed) import debugpy +import time class DebugpyMessageQueue: @@ -96,13 +97,19 @@ class DebugpyClient: def __init__(self, log, debugpy_stream, event_callback): self.log = log self.debugpy_stream = debugpy_stream - self.routing_id = None self.event_callback = event_callback self.message_queue = DebugpyMessageQueue(self._forward_event, self.log) + self.debugpy_host = '127.0.0.1' + self.debugpy_port = -1 + self.routing_id = None self.wait_for_attach = True self.init_event = Event() self.init_event_seq = -1 + def _get_endpoint(self): + host, port = self.get_host_port() + return 'tcp://' + host + ':' + str(port) + def _forward_event(self, msg): if msg['event'] == 'initialized': self.init_event.set() @@ -146,6 +153,28 @@ async def _handle_init_sequence(self): attach_rep = await self._wait_for_response() return attach_rep + def get_host_port(self): + if self.debugpy_port == -1: + socket = self.debugpy_stream.socket + socket.bind_to_random_port('tcp://' + self.debugpy_host) + self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode('utf-8') + socket.unbind(self.endpoint) + index = self.endpoint.rfind(':') + self.debugpy_port = self.endpoint[index+1:] + return self.debugpy_host, self.debugpy_port + + + def connect_tcp_socket(self): + self.debugpy_stream.socket.connect(self._get_endpoint()) + self.routing_id = self.debugpy_stream.socket.getsockopt(zmq.ROUTING_ID) + + def disconnect_tcp_socket(self): + self.debugpy_stream.socket.disconnect(self._get_endpoint()) + self.routing_id = None + self.init_event = Event() + self.init_event_seq = -1 + self.wait_for_attach = True + def receive_dap_frame(self, frame): self.message_queue.put_tcp_frame(frame) @@ -194,8 +223,11 @@ def __init__(self, log, debugpy_stream, event_callback, shell_socket, session): self.breakpoint_list = {} self.stopped_threads = [] + self.debugpy_initialized = False + self.debugpy_host = '127.0.0.1' self.debugpy_port = 0 + self.endpoint = None async def _forward_message(self, msg): return await self.debugpy_client.send_dap_request(msg) @@ -205,29 +237,24 @@ def tcp_client(self): return self.debugpy_client def start(self): - socket = self.debugpy_client.debugpy_stream.socket - socket.bind_to_random_port('tcp://' + self.debugpy_host) - endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode('utf-8') - socket.unbind(endpoint) - index = endpoint.rfind(':') - self.debugpy_port = endpoint[index+1:] - code = 'import debugpy;' - code += 'debugpy.listen(("' + self.debugpy_host + '",' + self.debugpy_port + '))' - content = { - 'code': code, - 'silent': True - } - self.session.send(self.shell_socket, 'execute_request', content, - None, (self.shell_socket.getsockopt(zmq.ROUTING_ID))) + if not self.debugpy_initialized: + host, port = self.debugpy_client.get_host_port() + code = 'import debugpy;' + code += 'debugpy.listen(("' + host + '",' + port + '))' + content = { + 'code': code, + 'silent': True + } + self.session.send(self.shell_socket, 'execute_request', content, + None, (self.shell_socket.getsockopt(zmq.ROUTING_ID))) - self.session.recv(self.shell_socket, mode=0) - socket.connect(endpoint) - debugpy.trace_this_thread(False) - return True + ident, msg = self.session.recv(self.shell_socket, mode=0) + self.debugpy_initialized = msg['content']['status'] == 'ok' + self.debugpy_client.connect_tcp_socket() + return self.debugpy_initialized def stop(self): - # TODO - pass + self.debugpy_client.disconnect_tcp_socket() async def dumpCell(self, message): code = message['arguments']['code'] @@ -289,9 +316,10 @@ async def variables(self, message): return reply async def attach(self, message): + host, port = self.debugpy_client.get_host_port() message['arguments']['connect'] = { - 'host': self.debugpy_host, - 'port': self.debugpy_port + 'host': host, + 'port': port } message['arguments']['logToFile'] = True return await self._forward_message(message)