Skip to content

Commit

Permalink
Merge pull request #508 from Textualize/msgpack-devtools
Browse files Browse the repository at this point in the history
Msgpack devtools
  • Loading branch information
willmcgugan authored May 13, 2022
2 parents 96ce420 + d20f129 commit 01ac3dd
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 72 deletions.
46 changes: 45 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ rich = "^12.3.0"
click = "8.1.2"
importlib-metadata = "^4.11.3"
typing-extensions = { version = "^4.0.0", python = "<3.8" }
msgpack = "^1.0.3"

[tool.poetry.dev-dependencies]
pytest = "^6.2.3"
Expand Down
29 changes: 17 additions & 12 deletions src/textual/devtools/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import asyncio
import base64
import datetime
import inspect
import json
import msgpack
import pickle
from time import time
from asyncio import Queue, Task, QueueFull
from io import StringIO
from typing import Type, Any, NamedTuple
Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(self, host: str = "127.0.0.1", port: int = DEVTOOLS_PORT) -> None:
self.update_console_task: Task | None = None
self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO())
self.websocket: ClientWebSocketResponse | None = None
self.log_queue: Queue[str | Type[ClientShutdown]] | None = None
self.log_queue: Queue[str | bytes | Type[ClientShutdown]] | None = None
self.spillover: int = 0

async def connect(self) -> None:
Expand Down Expand Up @@ -144,7 +145,10 @@ async def send_queued_logs():
if log is ClientShutdown:
log_queue.task_done()
break
await websocket.send_str(log)
if isinstance(log, str):
await websocket.send_str(log)
else:
await websocket.send_bytes(log)
log_queue.task_done()

self.log_queue_task = asyncio.create_task(send_queued_logs())
Expand Down Expand Up @@ -203,17 +207,18 @@ def log(self, log: DevtoolsLog) -> None:
segments = self.console.export_segments()

encoded_segments = self._encode_segments(segments)
message = json.dumps(
message: bytes | None = msgpack.packb(
{
"type": "client_log",
"payload": {
"timestamp": int(datetime.datetime.utcnow().timestamp()),
"timestamp": int(time()),
"path": getattr(log.caller, "filename", ""),
"line_number": getattr(log.caller, "lineno", 0),
"encoded_segments": encoded_segments,
"segments": encoded_segments,
},
}
)
assert message is not None
try:
if self.log_queue:
self.log_queue.put_nowait(message)
Expand All @@ -233,15 +238,15 @@ def log(self, log: DevtoolsLog) -> None:
except QueueFull:
self.spillover += 1

def _encode_segments(self, segments: list[Segment]) -> str:
"""Pickle and Base64 encode the list of Segments
@classmethod
def _encode_segments(cls, segments: list[Segment]) -> bytes:
"""Pickle a list of Segments
Args:
segments (list[Segment]): A list of Segments to encode
Returns:
str: The Segment list pickled with pickle protocol v3, then base64 encoded
bytes: The Segment list pickled with the latest protocol.
"""
pickled = pickle.dumps(segments, protocol=3)
encoded = base64.b64encode(pickled)
return str(encoded, encoding="utf-8")
pickled = pickle.dumps(segments, protocol=4)
return pickled
14 changes: 4 additions & 10 deletions src/textual/devtools/renderables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from datetime import datetime, timezone
from datetime import datetime
from pathlib import Path
from typing import Iterable

Expand Down Expand Up @@ -72,19 +72,13 @@ def __init__(
def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
local_time = (
datetime.fromtimestamp(self.unix_timestamp)
.replace(tzinfo=timezone.utc)
.astimezone(tz=datetime.now().astimezone().tzinfo)
)
timezone_name = local_time.tzname()
local_time = datetime.fromtimestamp(self.unix_timestamp)
table = Table.grid(expand=True)
table.add_column()
table.add_column()

file_link = escape(f"file://{Path(self.path).absolute()}")
file_and_line = escape(f"{Path(self.path).name}:{self.line_number}")
table.add_row(
f"[dim]{local_time.time()} {timezone_name}",
f"[dim]{local_time.time()}",
Align.right(
Text(f"{file_and_line}", style=Style(dim=True, link=file_link))
),
Expand Down
35 changes: 20 additions & 15 deletions src/textual/devtools/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from aiohttp.web_ws import WebSocketResponse
from rich.console import Console
from rich.markup import escape
import msgpack

from textual.devtools.renderables import (
DevConsoleLog,
Expand Down Expand Up @@ -160,26 +161,25 @@ async def _consume_incoming(self) -> None:
"""
last_message_time: float | None = None
while True:
message_json = await self.incoming_queue.get()
if message_json is None:
message = await self.incoming_queue.get()
if message is None:
self.incoming_queue.task_done()
break

type = message_json["type"]
type = message["type"]
if type == "client_log":
path = message_json["payload"]["path"]
line_number = message_json["payload"]["line_number"]
timestamp = message_json["payload"]["timestamp"]
encoded_segments = message_json["payload"]["encoded_segments"]
decoded_segments = base64.b64decode(encoded_segments)
segments = pickle.loads(decoded_segments)
path = message["payload"]["path"]
line_number = message["payload"]["line_number"]
timestamp = message["payload"]["timestamp"]
encoded_segments = message["payload"]["segments"]
segments = pickle.loads(encoded_segments)
message_time = time()
if (
last_message_time is not None
and message_time - last_message_time > 1
):
# Print a rule if it has been longer than a second since the last message
self.service.console.rule("")
self.service.console.rule()
self.service.console.print(
DevConsoleLog(
segments=segments,
Expand All @@ -190,7 +190,7 @@ async def _consume_incoming(self) -> None:
)
last_message_time = message_time
elif type == "client_spillover":
spillover = int(message_json["payload"]["spillover"])
spillover = int(message["payload"]["spillover"])
info_renderable = DevConsoleNotice(
f"Discarded {spillover} messages", level="warning"
)
Expand Down Expand Up @@ -219,21 +219,26 @@ async def run(self) -> WebSocketResponse:
await self.service.send_server_info(client_handler=self)
async for message in self.websocket:
message = cast(WSMessage, message)
if message.type == WSMsgType.TEXT:

if message.type in (WSMsgType.TEXT, WSMsgType.BINARY):

try:
message_json = json.loads(message.data)
if isinstance(message.data, bytes):
message = msgpack.unpackb(message.data)
else:
message = json.loads(message.data)
except JSONDecodeError:
self.service.console.print(escape(str(message.data)))
continue

type = message_json.get("type")
type = message.get("type")
if not type:
continue
if (
type in QUEUEABLE_TYPES
and not self.service.shutdown_event.is_set()
):
await self.incoming_queue.put(message_json)
await self.incoming_queue.put(message)
elif message.type == WSMsgType.ERROR:
self.service.console.print(
DevConsoleNotice("Websocket error occurred", level="error")
Expand Down
20 changes: 8 additions & 12 deletions tests/devtools/test_devtools.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
from datetime import datetime, timezone
from datetime import datetime

import pytest
import time_machine
from rich.align import Align
from rich.console import Console
from rich.segment import Segment

import msgpack
from tests.utilities.render import wait_for_predicate
from textual.devtools.renderables import DevConsoleLog, DevConsoleNotice

TIMESTAMP = 1649166819
WIDTH = 40
# The string "Hello, world!" is encoded in the payload below
EXAMPLE_LOG = {
_EXAMPLE_LOG = {
"type": "client_log",
"payload": {
"encoded_segments": "gASVQgAAAAAAAABdlCiMDHJpY2guc2VnbWVudJSMB1NlZ"
"21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=",
"segments": b"\x80\x04\x955\x00\x00\x00\x00\x00\x00\x00]\x94\x8c\x0crich.segment\x94\x8c\x07Segment\x94\x93\x94\x8c\rHello, world!\x94NN\x87\x94\x81\x94a.",
"line_number": 123,
"path": "abc/hello.py",
"timestamp": TIMESTAMP,
},
}
EXAMPLE_LOG = msgpack.packb(_EXAMPLE_LOG)


@pytest.fixture(scope="module")
Expand All @@ -48,15 +49,10 @@ def test_log_message_render(console):
right: Align = right_cells[0]

# Since we can't guarantee the timezone the tests will run in...
local_time = (
datetime.fromtimestamp(TIMESTAMP)
.replace(tzinfo=timezone.utc)
.astimezone(tz=datetime.now().astimezone().tzinfo)
)
timezone_name = local_time.tzname()
local_time = datetime.fromtimestamp(TIMESTAMP)
string_timestamp = local_time.time()

assert left == f"[dim]{string_timestamp} {timezone_name}"
assert left == f"[dim]{string_timestamp}"
assert right.align == "right"
assert "hello.py:123" in right.renderable

Expand All @@ -69,7 +65,7 @@ def test_internal_message_render(console):


async def test_devtools_valid_client_log(devtools):
await devtools.websocket.send_json(EXAMPLE_LOG)
await devtools.websocket.send_bytes(EXAMPLE_LOG)
assert devtools.is_connected


Expand Down
32 changes: 18 additions & 14 deletions tests/devtools/test_devtools_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aiohttp.web_ws import WebSocketResponse
from rich.console import ConsoleDimensions
from rich.panel import Panel
import msgpack

from tests.utilities.render import wait_for_predicate
from textual.devtools.client import DevtoolsClient
Expand All @@ -27,36 +28,39 @@ async def test_devtools_client_is_connected(devtools):
assert devtools.is_connected


@time_machine.travel(datetime.fromtimestamp(TIMESTAMP))
@time_machine.travel(datetime.utcfromtimestamp(TIMESTAMP))
async def test_devtools_log_places_encodes_and_queues_message(devtools):

await devtools._stop_log_queue_processing()
devtools.log(DevtoolsLog("Hello, world!", CALLER))
queued_log = await devtools.log_queue.get()
queued_log_json = json.loads(queued_log)
assert queued_log_json == {
queued_log_data = msgpack.unpackb(queued_log)
print(repr(queued_log_data))
assert queued_log_data == {
"type": "client_log",
"payload": {
"timestamp": TIMESTAMP,
"path": CALLER_PATH,
"line_number": CALLER_LINENO,
"encoded_segments": "gANdcQAoY3JpY2guc2VnbWVudApTZWdtZW50CnEBWA0AAABIZWxsbywgd29ybGQhcQJOTodxA4FxBGgBWAEAAAAKcQVOTodxBoFxB2Uu",
"timestamp": 1649166819,
"path": "a/b/c.py",
"line_number": 123,
"segments": b"\x80\x04\x95B\x00\x00\x00\x00\x00\x00\x00]\x94(\x8c\x0crich.segment\x94\x8c\x07Segment\x94\x93\x94\x8c\rHello, world!\x94NN\x87\x94\x81\x94h\x03\x8c\x01\n\x94NN\x87\x94\x81\x94e.",
},
}


@time_machine.travel(datetime.fromtimestamp(TIMESTAMP))
@time_machine.travel(datetime.utcfromtimestamp(TIMESTAMP))
async def test_devtools_log_places_encodes_and_queues_many_logs_as_string(devtools):
await devtools._stop_log_queue_processing()
devtools.log(DevtoolsLog(("hello", "world"), CALLER))
queued_log = await devtools.log_queue.get()
queued_log_json = json.loads(queued_log)
assert queued_log_json == {
queued_log_data = msgpack.unpackb(queued_log)
print(repr(queued_log_data))
assert queued_log_data == {
"type": "client_log",
"payload": {
"timestamp": TIMESTAMP,
"path": CALLER_PATH,
"line_number": CALLER_LINENO,
"encoded_segments": "gANdcQAoY3JpY2guc2VnbWVudApTZWdtZW50CnEBWAsAAABoZWxsbyB3b3JsZHECTk6HcQOBcQRoAVgBAAAACnEFTk6HcQaBcQdlLg==",
"timestamp": 1649166819,
"path": "a/b/c.py",
"line_number": 123,
"segments": b"\x80\x04\x95@\x00\x00\x00\x00\x00\x00\x00]\x94(\x8c\x0crich.segment\x94\x8c\x07Segment\x94\x93\x94\x8c\x0bhello world\x94NN\x87\x94\x81\x94h\x03\x8c\x01\n\x94NN\x87\x94\x81\x94e.",
},
}

Expand Down
Loading

0 comments on commit 01ac3dd

Please sign in to comment.