Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix message parsing for documents containing Content-Length keyword #80

Merged
merged 9 commits into from
Sep 3, 2019
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
- [Max O'Cull](https://github.com/Maxattax97)
- [Tomoya Tanjo](https://github.com/tom-tan)
- [yorodm](https://github.com/yorodm)
- [Denis Loginov](https://github.com/dinvlad)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make this alphabetical?

43 changes: 34 additions & 9 deletions pygls/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ class JsonRPCProtocol(asyncio.Protocol):

This class provides bidirectional communication which is needed for LSP.
"""
BODY_PATTERN = re.compile(rb'\{.+?\}.*')

CANCEL_REQUEST = '$/cancelRequest'

Expand All @@ -188,6 +187,7 @@ def __init__(self, server):

self.fm = FeatureManager(server)
self.transport = None
self._message_buf = []

def __call__(self):
return self
Expand Down Expand Up @@ -397,18 +397,43 @@ def connection_made(self, transport: asyncio.Transport):
"""Method from base class, called when connection is established"""
self.transport = transport

MESSAGE_PATTERN = re.compile(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this up where other class constants are (alphabetically)?

rb'^(?:[^\r\n]+\r\n)*' +
rb'Content-Length: (?P<length>\d+)\r\n' +
rb'(?:[^\r\n]+\r\n)*\r\n' +
rb'(?P<body>{.*)',
re.DOTALL,
)

def data_received(self, data: bytes):
"""Method from base class, called when server receives the data"""
logger.debug('Received {}'.format(data))

for part in data.split(b'Content-Length'):
try:
body = JsonRPCProtocol.BODY_PATTERN.findall(part)[0]
self._procedure_handler(
json.loads(body.decode(self.CHARSET),
object_hook=deserialize_message))
except IndexError:
pass
while len(data):
# Append the incoming chunk to the message buffer
self._message_buf.append(data)

# Look for the body of the message
message = b''.join(self._message_buf)
found = JsonRPCProtocol.MESSAGE_PATTERN.fullmatch(message)

body = found.group('body') if found else b''
length = int(found.group('length')) if found else 1

if len(body) < length:
# Message is incomplete; bail until more data arrives
return

# Message is complete;
# extract the body and any remaining data,
# and reset the buffer for the next message
body, data = body[:length], body[length:]
self._message_buf = []

# Parse the body
self._procedure_handler(
json.loads(body.decode(self.CHARSET),
object_hook=deserialize_message))

def notify(self, method: str, params=None):
"""Sends a JSON RPC notification to the client."""
Expand Down
55 changes: 30 additions & 25 deletions pygls/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
############################################################################
import asyncio
import logging
import re
import sys
from concurrent.futures import Future, ThreadPoolExecutor
from multiprocessing.pool import ThreadPool
from re import findall
from threading import Event
from typing import Callable, Dict, List

Expand All @@ -37,33 +37,38 @@

async def aio_readline(loop, executor, stop_event, rfile, proxy):
"""Reads data from stdin in separate thread (asynchronously)."""
while not stop_event.is_set():
# Read line
line = await loop.run_in_executor(executor, rfile.readline)

if not line:
continue

# Extract content length from line
try:
content_length = int(findall(rb'\b\d+\b', line)[0])
logger.debug('Content length: {}'.format(content_length))
except IndexError:
continue

# Throw away empty lines
while line and line.strip():
line = await loop.run_in_executor(executor, rfile.readline)

if not line:
continue
CONTENT_LENGTH_PATTERN = re.compile(rb'^Content-Length: (\d+)\r\n$')

# Read body
body = await loop.run_in_executor(executor, rfile.read, content_length)
# Initialize message buffer
message = []
content_length = 0

# Pass body to language server protocol
if body:
proxy(body)
while not stop_event.is_set():
# Read a header line
header = await loop.run_in_executor(executor, rfile.readline)
message.append(header)

# Extract content length if possible
if not content_length:
match = CONTENT_LENGTH_PATTERN.fullmatch(header)
if match:
content_length = int(match.group(1))
logger.debug('Content length: {}'.format(content_length))

# Check if all headers have been read (as indicated by an empty line \r\n)
if content_length and not header.strip():

# Read body
body = await loop.run_in_executor(executor, rfile.read, content_length)
message.append(body)

# Pass message to language server protocol
proxy(b''.join(message))

# Reset the buffer
message = []
content_length = 0


class StdOutTransportAdapter:
Expand Down
69 changes: 69 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,75 @@ def test_deserialize_message_should_return_request_message():
assert result.params == "1"


def test_data_received_without_content_type_should_handle_message(client_server):
    """Feed the server protocol a message whose only header is Content-Length.

    The test makes no explicit assertions; it passes as long as
    ``data_received`` does not raise.
    """
    _client, server = client_server

    payload = json.dumps({
        "jsonrpc": "2.0",
        "method": "test",
        "params": 1,
    })
    # Header line, blank separator line, then the JSON body.
    lines = [
        'Content-Length: ' + str(len(payload)),
        '',
        payload,
    ]
    raw = '\r\n'.join(lines).encode('utf-8')

    server.lsp.data_received(raw)


def test_data_received_content_type_first_should_handle_message(client_server):
    """Feed the server protocol a message with Content-Type before Content-Length.

    The test makes no explicit assertions; it passes as long as
    ``data_received`` does not raise.
    """
    _client, server = client_server

    payload = json.dumps({
        "jsonrpc": "2.0",
        "method": "test",
        "params": 1,
    })
    # Content-Type deliberately precedes Content-Length in the header section.
    lines = [
        'Content-Type: application/vscode-jsonrpc; charset=utf-8',
        'Content-Length: ' + str(len(payload)),
        '',
        payload,
    ]
    raw = '\r\n'.join(lines).encode('utf-8')

    server.lsp.data_received(raw)


def dummy_message(param=1):
    """Build a framed JSON-RPC test message and return it as UTF-8 bytes.

    The message consists of a ``Content-Length`` header, a ``Content-Type``
    header, a blank line, and a JSON body whose ``params`` field is *param*.
    ``Content-Length`` is set to the length of the JSON text.
    """
    payload = json.dumps({
        "jsonrpc": "2.0",
        "method": "test",
        "params": param,
    })
    header_lines = [
        'Content-Length: ' + str(len(payload)),
        'Content-Type: application/vscode-jsonrpc; charset=utf-8',
    ]
    # The empty string yields the blank line separating headers from body.
    framed = '\r\n'.join(header_lines + ['', payload])
    return framed.encode('utf-8')


def test_data_received_single_message_should_handle_message(client_server):
    """Deliver one complete dummy message to the server protocol in one call.

    The test makes no explicit assertions; it passes as long as
    ``data_received`` does not raise.
    """
    _client, server = client_server
    server.lsp.data_received(dummy_message())


def test_data_received_partial_message_should_handle_message(client_server):
    """Deliver one dummy message split across two ``data_received`` calls.

    The last five bytes are withheld from the first call and sent in the
    second. The test makes no explicit assertions; it passes as long as
    neither call raises.
    """
    _client, server = client_server
    payload = dummy_message()

    # Split point: everything except the final 5 bytes goes in the first chunk.
    cut = len(payload) - 5
    first_chunk = payload[:cut]
    second_chunk = payload[cut:]

    server.lsp.data_received(first_chunk)
    server.lsp.data_received(second_chunk)


def test_data_received_multi_message_should_handle_messages(client_server):
    """Deliver three concatenated dummy messages in a single call.

    The messages carry ``params`` values 0, 1 and 2. The test makes no
    explicit assertions; it passes as long as ``data_received`` does not raise.
    """
    _client, server = client_server
    # Back-to-back messages with no separator between them.
    stream = b''.join(map(dummy_message, range(3)))
    server.lsp.data_received(stream)


def test_initialize_without_capabilities_should_raise_error(client_server):
_, server = client_server
params = dictToObj({
Expand Down