From 9ab62de9eccc3ae222b601ac066cc515e150d219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20=C4=8Cerm=C3=A1k?= Date: Wed, 9 Oct 2024 16:38:50 +0200 Subject: [PATCH] Refactor journal_logs_reader to always return the cursor --- supervisor/api/host.py | 16 ++++++------ supervisor/utils/systemd_journal.py | 18 ++++--------- tests/api/test_host.py | 12 ++++----- tests/host/test_logs.py | 12 +++++++-- tests/utils/test_systemd_journal.py | 39 ++++++++++++++--------------- 5 files changed, 47 insertions(+), 50 deletions(-) diff --git a/supervisor/api/host.py b/supervisor/api/host.py index 33b5a5bda34..b6455848c50 100644 --- a/supervisor/api/host.py +++ b/supervisor/api/host.py @@ -240,15 +240,13 @@ async def advanced_logs_handler( try: response = web.StreamResponse() response.content_type = CONTENT_TYPE_TEXT - - async def finish_prepare(cursor: str): - if cursor: - response.headers["X-First-Cursor"] = cursor - await response.prepare(request) - - async for line in journal_logs_reader( - resp, log_formatter, finish_prepare - ): + headers_returned = False + async for cursor, line in journal_logs_reader(resp, log_formatter): + if not headers_returned: + if cursor: + response.headers["X-First-Cursor"] = cursor + await response.prepare(request) + headers_returned = True await response.write(line.encode("utf-8") + b"\n") except ConnectionResetError as ex: raise APIError( diff --git a/supervisor/utils/systemd_journal.py b/supervisor/utils/systemd_journal.py index af87ecaa97a..5898d811552 100644 --- a/supervisor/utils/systemd_journal.py +++ b/supervisor/utils/systemd_journal.py @@ -1,6 +1,6 @@ """Utilities for working with systemd journal export format.""" -from collections.abc import AsyncGenerator, Callable, Coroutine +from collections.abc import AsyncGenerator from datetime import UTC, datetime from functools import wraps @@ -60,14 +60,11 @@ def journal_verbose_formatter(entries: dict[str, str]) -> str: async def journal_logs_reader( - journal_logs: ClientResponse, - log_formatter: LogFormatter = LogFormatter.PLAIN, - first_cursor_callback: Callable[[str], Coroutine] | None = None, -) -> AsyncGenerator[str, None]: + journal_logs: ClientResponse, log_formatter: LogFormatter = LogFormatter.PLAIN +) -> AsyncGenerator[(str | None, str), None]: """Read logs from systemd journal line by line, formatted using the given formatter. - Optionally takes a first_cursor_callback which is an async function that is called with - the journal cursor value found in the first log entry and awaited. + Returns a generator of (cursor, formatted_entry) tuples. """ match log_formatter: case LogFormatter.PLAIN: @@ -77,8 +74,6 @@ async def journal_logs_reader( case _: raise ValueError(f"Unknown log format: {log_formatter}") - call_cursor_callback = first_cursor_callback is not None - async with journal_logs as resp: entries: dict[str, str] = {} while not resp.content.at_eof(): @@ -87,10 +82,7 @@ async def journal_logs_reader( # at EOF (likely race between at_eof and EOF check in readuntil) if line == b"\n" or not line: if entries: - if call_cursor_callback: - await first_cursor_callback(entries.get("__CURSOR")) - call_cursor_callback = False - yield formatter_(entries) + yield entries.get("__CURSOR"), formatter_(entries) entries = {} continue diff --git a/tests/api/test_host.py b/tests/api/test_host.py index fb02d6400b8..9c5d2901677 100644 --- a/tests/api/test_host.py +++ b/tests/api/test_host.py @@ -262,7 +262,7 @@ async def test_advaced_logs_query_parameters( range_header=DEFAULT_RANGE, accept=LogFormat.JOURNAL, ) - journal_logs_reader.assert_called_with(ANY, LogFormatter.VERBOSE, ANY) + journal_logs_reader.assert_called_with(ANY, LogFormatter.VERBOSE) journal_logs_reader.reset_mock() journald_logs.reset_mock() @@ -280,7 +280,7 @@ async def test_advaced_logs_query_parameters( range_header="entries=:-53:", accept=LogFormat.JOURNAL, ) - journal_logs_reader.assert_called_with(ANY, LogFormatter.VERBOSE, ANY) + journal_logs_reader.assert_called_with(ANY, LogFormatter.VERBOSE) async def test_advanced_logs_boot_id_offset( @@ -333,24 +333,24 @@ async def test_advanced_logs_formatters( """Test advanced logs formatters varying on Accept header.""" await api_client.get("/host/logs") - journal_logs_reader.assert_called_once_with(ANY, LogFormatter.VERBOSE, ANY) + journal_logs_reader.assert_called_once_with(ANY, LogFormatter.VERBOSE) journal_logs_reader.reset_mock() headers = {"Accept": "text/x-log"} await api_client.get("/host/logs", headers=headers) - journal_logs_reader.assert_called_once_with(ANY, LogFormatter.VERBOSE, ANY) + journal_logs_reader.assert_called_once_with(ANY, LogFormatter.VERBOSE) journal_logs_reader.reset_mock() await api_client.get("/host/logs/identifiers/test") - journal_logs_reader.assert_called_once_with(ANY, LogFormatter.PLAIN, ANY) + journal_logs_reader.assert_called_once_with(ANY, LogFormatter.PLAIN) journal_logs_reader.reset_mock() headers = {"Accept": "text/x-log"} await api_client.get("/host/logs/identifiers/test", headers=headers) - journal_logs_reader.assert_called_once_with(ANY, LogFormatter.VERBOSE, ANY) + journal_logs_reader.assert_called_once_with(ANY, LogFormatter.VERBOSE) async def test_advanced_logs_errors(api_client: TestClient): diff --git a/tests/host/test_logs.py b/tests/host/test_logs.py index 2017cc96bfc..86c7bf49b4d 100644 --- a/tests/host/test_logs.py +++ b/tests/host/test_logs.py @@ -40,9 +40,13 @@ async def test_logs(coresys: CoreSys, journald_gateway: MagicMock): journald_gateway.feed_eof() async with coresys.host.logs.journald_logs() as resp: - line = await anext( + cursor, line = await anext( journal_logs_reader(resp, log_formatter=LogFormatter.VERBOSE) ) + assert ( + cursor + == "s=83fee99ca0c3466db5fc120d52ca7dd8;i=203f2ce;b=f5a5c442fa6548cf97474d2d57c920b3;m=3191a3c620;t=612ccd299e7af;x=8675b540119d10bb" + ) assert ( line == "2024-03-04 02:52:56.193 homeassistant systemd[1]: Started Hostname Service." @@ -64,7 +68,11 @@ async def test_logs_coloured(coresys: CoreSys, journald_gateway: MagicMock): journald_gateway.feed_eof() async with coresys.host.logs.journald_logs() as resp: - line = await anext(journal_logs_reader(resp)) + cursor, line = await anext(journal_logs_reader(resp)) + assert ( + cursor + == "s=83fee99ca0c3466db5fc120d52ca7dd8;i=2049389;b=f5a5c442fa6548cf97474d2d57c920b3;m=4263828e8c;t=612dda478b01b;x=9ae12394c9326930" + ) assert ( line == "\x1b[32m24-03-04 23:56:56 INFO (MainThread) [__main__] Closing Supervisor\x1b[0m" diff --git a/tests/utils/test_systemd_journal.py b/tests/utils/test_systemd_journal.py index ab7e0ad1368..4afe6248ef4 100644 --- a/tests/utils/test_systemd_journal.py +++ b/tests/utils/test_systemd_journal.py @@ -1,7 +1,7 @@ """Test systemd journal utilities.""" import asyncio -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import ANY, MagicMock import pytest @@ -89,7 +89,7 @@ async def test_parsing_simple(): """Test plain formatter.""" journal_logs, stream = _journal_logs_mock() stream.feed_data(b"MESSAGE=Hello, world!\n\n") - line = await anext(journal_logs_reader(journal_logs)) + _, line = await anext(journal_logs_reader(journal_logs)) assert line == "Hello, world!" @@ -103,7 +103,7 @@ async def test_parsing_verbose(): b"_PID=666\n" b"MESSAGE=Hello, world!\n\n" ) - line = await anext( + _, line = await anext( journal_logs_reader(journal_logs, log_formatter=LogFormatter.VERBOSE) ) assert line == "2013-09-17 07:32:51.000 homeassistant python[666]: Hello, world!" @@ -118,7 +118,7 @@ async def test_parsing_newlines_in_message(): b"AFTER=after\n\n" ) - line = await anext(journal_logs_reader(journal_logs)) + _, line = await anext(journal_logs_reader(journal_logs)) assert line == "Hello,\nworld!" @@ -135,8 +135,8 @@ async def test_parsing_newlines_in_multiple_fields(): b"AFTER=after\n\n" ) - assert await anext(journal_logs_reader(journal_logs)) == "Hello,\nworld!\n" - assert await anext(journal_logs_reader(journal_logs)) == "Hello,\nworld!" + assert await anext(journal_logs_reader(journal_logs)) == (ANY, "Hello,\nworld!\n") + assert await anext(journal_logs_reader(journal_logs)) == (ANY, "Hello,\nworld!") async def test_parsing_two_messages(): @@ -151,8 +151,8 @@ async def test_parsing_two_messages(): stream.feed_eof() reader = journal_logs_reader(journal_logs) - assert await anext(reader) == "Hello, world!" - assert await anext(reader) == "Hello again, world!" + assert await anext(reader) == (ANY, "Hello, world!") + assert await anext(reader) == (ANY, "Hello again, world!") with pytest.raises(StopAsyncIteration): await anext(reader) @@ -167,14 +167,15 @@ async def test_cursor_callback(): b"__CURSOR=cursor2\n" b"MESSAGE=Hello again, world!\n" b"ID=2\n\n" + b"MESSAGE=No cursor\n" + b"ID=2\n\n" ) stream.feed_eof() - cursor_callback = AsyncMock() - reader = journal_logs_reader(journal_logs, first_cursor_callback=cursor_callback) - assert await anext(reader) == "Hello, world!" - assert await anext(reader) == "Hello again, world!" - cursor_callback.assert_called_once_with("cursor1") + reader = journal_logs_reader(journal_logs) + assert await anext(reader) == ("cursor1", "Hello, world!") + assert await anext(reader) == ("cursor2", "Hello again, world!") + assert await anext(reader) == (None, "No cursor") with pytest.raises(StopAsyncIteration): await anext(reader) @@ -190,11 +191,9 @@ async def test_cursor_callback_no_cursor(): ) stream.feed_eof() - cursor_callback = AsyncMock() - reader = journal_logs_reader(journal_logs, first_cursor_callback=cursor_callback) - assert await anext(reader) == "Hello, world!" - assert await anext(reader) == "Hello again, world!" - cursor_callback.assert_called_once_with(None) + reader = journal_logs_reader(journal_logs) + assert await anext(reader) == (ANY, "Hello, world!") + assert await anext(reader) == (ANY, "Hello again, world!") with pytest.raises(StopAsyncIteration): await anext(reader) @@ -216,7 +215,7 @@ async def test_parsing_journal_host_logs(): """Test parsing of real host logs.""" journal_logs, stream = _journal_logs_mock() stream.feed_data(load_fixture("logs_export_host.txt").encode("utf-8")) - line = await anext(journal_logs_reader(journal_logs)) + _, line = await anext(journal_logs_reader(journal_logs)) assert line == "Started Hostname Service." @@ -224,7 +223,7 @@ async def test_parsing_colored_supervisor_logs(): """Test parsing of real logs with ANSI escape sequences.""" journal_logs, stream = _journal_logs_mock() stream.feed_data(load_fixture("logs_export_supervisor.txt").encode("utf-8")) - line = await anext(journal_logs_reader(journal_logs)) + _, line = await anext(journal_logs_reader(journal_logs)) assert ( line == "\x1b[32m24-03-04 23:56:56 INFO (MainThread) [__main__] Closing Supervisor\x1b[0m"