From f333dc04d34880d37ba9c99fde1b58c47631e733 Mon Sep 17 00:00:00 2001 From: Raoul Wols Date: Wed, 11 Mar 2020 21:02:58 +0100 Subject: [PATCH] Account for server request ID strings Request IDs can be strings or integers. We, the text editor, always use integers for request IDs. But some servers may send out e.g. GUIDs as request IDs. I decided to type-erase the request ID in various places, because all we're doing in most functions is passing through the received request ID from the server all the way to our Response object, finally serializing it back to JSON again. --- Syntaxes/ServerLog.sublime-syntax | 4 ++-- plugin/core/protocol.py | 2 +- plugin/core/rpc.py | 26 +++++++++++--------------- plugin/core/sessions.py | 4 ++-- plugin/core/windows.py | 2 +- tests/test_rpc.py | 24 ++++++++++++++++++------ 6 files changed, 35 insertions(+), 27 deletions(-) diff --git a/Syntaxes/ServerLog.sublime-syntax b/Syntaxes/ServerLog.sublime-syntax index ffda8e4a2..32464566a 100644 --- a/Syntaxes/ServerLog.sublime-syntax +++ b/Syntaxes/ServerLog.sublime-syntax @@ -41,8 +41,8 @@ contexts: - - match: \( scope: punctuation.section.parens.begin.lsp set: - - match: \d+ - scope: constant.numeric.integer.decimal.lsp + - match: '[^\s)]+' + scope: constant.numeric.lsp set: - match: \) scope: punctuation.section.parens.end.lsp diff --git a/plugin/core/protocol.py b/plugin/core/protocol.py index 9371ea1b1..f10b5fb70 100644 --- a/plugin/core/protocol.py +++ b/plugin/core/protocol.py @@ -194,7 +194,7 @@ class Response: __slots__ = ('request_id', 'result') - def __init__(self, request_id: int, result: Union[None, Mapping[str, Any], Iterable[Any]]) -> None: + def __init__(self, request_id: Any, result: Union[None, Mapping[str, Any], Iterable[Any]]) -> None: self.request_id = request_id self.result = result diff --git a/plugin/core/rpc.py b/plugin/core/rpc.py index e4f6f70d7..2b4506fc4 100644 --- a/plugin/core/rpc.py +++ b/plugin/core/rpc.py @@ -2,7 +2,7 @@ from .protocol import Request, Notification, Response from .transports import StdioTransport, Transport from .types import Settings -from .typing import Any, Dict, Tuple, Callable, Optional, Union, Mapping +from .typing import Any, Dict, Tuple, Callable, Optional, Mapping from threading import Condition from threading import Lock import subprocess @@ -42,16 +42,16 @@ def log(self, message: str, params: Any, log_payload: bool) -> None: message = "{}: {}".format(message, params) self.sink(message) - def format_response(self, direction: str, request_id: int) -> str: + def format_response(self, direction: str, request_id: Any) -> str: return "{} {} {}".format(direction, self.server_name, request_id) - def format_request(self, direction: str, method: str, request_id: int) -> str: + def format_request(self, direction: str, method: str, request_id: Any) -> str: return "{} {} {}({})".format(direction, self.server_name, method, request_id) def format_notification(self, direction: str, method: str) -> str: return "{} {} {}".format(direction, self.server_name, method) - def outgoing_response(self, request_id: int, params: Any) -> None: + def outgoing_response(self, request_id: Any, params: Any) -> None: if not self.settings.log_debug: return self.log(self.format_response(Direction.Outgoing, request_id), params, self.settings.log_payloads) @@ -79,7 +79,7 @@ def incoming_response(self, request_id: int, params: Any) -> None: return self.log(self.format_response(Direction.Incoming, request_id), params, self.settings.log_payloads) - def incoming_request(self, request_id: int, method: str, params: Any, unhandled: bool) -> None: + def incoming_request(self, request_id: Any, method: str, params: Any, unhandled: bool) -> None: if not self.settings.log_debug: return direction = "unhandled" if unhandled else Direction.Incoming @@ -96,7 +96,7 @@ class Client(object): def __init__(self, transport: Transport, settings: Settings) -> None: self.transport = transport # type: Optional[Transport] self.transport.start(self.receive_payload, self.on_transport_closed) - self.request_id = 0 + self.request_id = 0 # Our request IDs are always integers. self.logger = PreformattedPayloadLogger(settings, "server", debug) self._response_handlers = {} # type: Dict[int, Tuple[Optional[Callable], Optional[Callable[[Any], None]]]] self._request_handlers = {} # type: Dict[str, Callable] @@ -245,20 +245,16 @@ def on_notification(self, notification_method: str, handler: Callable) -> None: def request_or_notification_handler(self, payload: Mapping[str, Any]) -> None: method = payload["method"] # type: str params = payload.get("params") - request_id = payload.get("id") # type: Union[str, int, None] + # Server request IDs can be either a string or an int. + request_id = payload.get("id") if request_id is not None: - request_id_int = int(request_id) - - def log(method: str, params: Any, unhandled: bool) -> None: - nonlocal request_id_int - self.logger.incoming_request(request_id_int, method, params, unhandled) - - self.handle(request_id_int, method, params, "request", self._request_handlers, log) + self.handle(request_id, method, params, "request", self._request_handlers, + lambda a, b, c: self.logger.incoming_request(request_id, a, b, c)) else: self.handle(None, method, params, "notification", self._notification_handlers, self.logger.incoming_notification) - def handle(self, request_id: Optional[int], method: str, params: Any, typestr: str, + def handle(self, request_id: Any, method: str, params: Any, typestr: str, handlers: Mapping[str, Callable], log: Callable[[str, Any, bool], None]) -> None: handler = handlers.get(method) log(method, params, handler is None) diff --git a/plugin/core/sessions.py b/plugin/core/sessions.py index 91e03aee6..968e4f806 100644 --- a/plugin/core/sessions.py +++ b/plugin/core/sessions.py @@ -272,10 +272,10 @@ def _handle_initialize_result(self, result: Any) -> None: if self._on_post_initialize: self._on_post_initialize(self, None) - def _handle_request_workspace_folders(self, _: Any, request_id: int) -> None: + def _handle_request_workspace_folders(self, _: Any, request_id: Any) -> None: self.client.send_response(Response(request_id, [wf.to_lsp() for wf in self._workspace_folders])) - def _handle_request_workspace_configuration(self, params: Dict[str, Any], request_id: int) -> None: + def _handle_request_workspace_configuration(self, params: Dict[str, Any], request_id: Any) -> None: items = [] # type: List[Any] requested_items = params.get("items") or [] for requested_item in requested_items: diff --git a/plugin/core/windows.py b/plugin/core/windows.py index 1e3b2f42c..eabe0e89c 100644 --- a/plugin/core/windows.py +++ b/plugin/core/windows.py @@ -484,7 +484,7 @@ def _start_client(self, config: ClientConfig, file_path: str) -> None: debug("window {} added session {}".format(self._window.id(), config.name)) self._sessions.setdefault(config.name, []).append(session) - def _handle_message_request(self, params: dict, client: Client, request_id: int) -> None: + def _handle_message_request(self, params: dict, client: Client, request_id: Any) -> None: actions = params.get("actions", []) titles = list(action.get("title") for action in actions) diff --git a/tests/test_rpc.py b/tests/test_rpc.py index dc5e8fc2f..c86507bf9 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -5,14 +5,10 @@ from LSP.plugin.core.rpc import format_request from LSP.plugin.core.transports import Transport from LSP.plugin.core.types import Settings +from LSP.plugin.core.typing import List, Tuple, Dict, Any from test_mocks import MockSettings import json import unittest -try: - from typing import Any, List, Dict, Tuple, Callable, Optional - assert Any and List and Dict and Tuple and Callable and Optional -except ImportError: - pass def return_empty_dict_result(message): @@ -143,14 +139,30 @@ def test_server_request(self): client = Client(transport, settings) self.assertIsNotNone(client) self.assertTrue(transport.has_started) - pings = [] # type: List[Tuple[int, Dict[str, Any]]] + pings = [] # type: List[Tuple[Any, Dict[str, Any]]] client.on_request( "ping", lambda params, request_id: pings.append((request_id, params))) transport.receive('{ "id": 42, "method": "ping"}') self.assertEqual(len(pings), 1) + self.assertIsInstance(pings[0][0], int) self.assertEqual(pings[0][0], 42) + def test_server_request_non_integer_request(self): + transport = MockTransport() + settings = MockSettings() + client = Client(transport, settings) + self.assertIsNotNone(client) + self.assertTrue(transport.has_started) + pings = [] # type: List[Tuple[Any, Dict[str, Any]]] + client.on_request( + "ping", + lambda params, request_id: pings.append((request_id, params))) + transport.receive('{ "id": "abcd-1234-efgh-5678", "method": "ping"}') + self.assertEqual(len(pings), 1) + self.assertIsInstance(pings[0][0], str) + self.assertEqual(pings[0][0], "abcd-1234-efgh-5678") + def test_error_response_handler(self): transport = MockTransport(return_error) settings = MockSettings()