diff --git a/CHANGELOG.md b/CHANGELOG.md index ce7c782824..472c4bfda7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fixed setting `TreeNode.label` on an existing `Tree` node not immediately https://github.com/Textualize/textual/pull/2713 - Correctly implement `__eq__` protocol in DataTable https://github.com/Textualize/textual/pull/2705 +- Fixed exceptions in Pilot tests being silently ignored https://github.com/Textualize/textual/pull/2754 +- Fixed issue where internal data of `OptionList` could be invalid for short window after `clear_options` https://github.com/Textualize/textual/pull/2754 - Fixed `Tooltip` causing a `query_one` on a lone `Static` to fail https://github.com/Textualize/textual/issues/2723 ### Changed diff --git a/src/textual/app.py b/src/textual/app.py index 4aed6b0394..a48aae41f7 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -183,7 +183,6 @@ class CssPathError(Exception): ReturnType = TypeVar("ReturnType") - CSSPathType = Union[ str, PurePath, @@ -367,6 +366,13 @@ def __init__( self._animate = self._animator.bind(self) self.mouse_position = Offset(0, 0) + self._exception: Exception | None = None + """The unhandled exception which is leading to the app shutting down, + or None if the app is still running with no unhandled exceptions.""" + + self._exception_event: asyncio.Event = asyncio.Event() + """An event that will be set when the first exception is encountered.""" + self.title = ( self.TITLE if self.TITLE is not None else f"{self.__class__.__name__}" ) @@ -1108,6 +1114,9 @@ async def run_app(app) -> None: # Shutdown the app cleanly await app._shutdown() await app_task + # Re-raise the exception which caused panic so test frameworks are aware + if self._exception: + raise self._exception async def run_async( self, @@ -1782,9 +1791,17 @@ def render(renderable: RenderableType) -> list[Segment]: def _handle_exception(self, error: Exception) -> None: """Called with an unhandled exception. + Always results in the app exiting. + Args: error: An exception instance. """ + # If we're running via pilot and this is the first exception encountered, + # take note of it so that we can re-raise for test frameworks later. + if self.is_headless and self._exception is None: + self._exception = error + self._exception_event.set() + if hasattr(error, "__rich__"): # Exception has a rich method, so we can defer to that for the rendering self.panic(error) diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index 38dd851336..ca7aa0047f 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -494,6 +494,7 @@ async def _process_messages_loop(self) -> None: """Process messages until the queue is closed.""" _rich_traceback_guard = True self._thread_id = threading.get_ident() + while not self._closed: try: message = await self._get_message() diff --git a/src/textual/pilot.py b/src/textual/pilot.py index 041e00e13e..64cd602bf2 100644 --- a/src/textual/pilot.py +++ b/src/textual/pilot.py @@ -42,6 +42,10 @@ def _get_mouse_message_arguments( return message_arguments +class WaitForScreenTimeout(Exception): + pass + + @rich.repr.auto(angular=True) class Pilot(Generic[ReturnType]): """Pilot object to drive an app.""" @@ -134,13 +138,17 @@ async def hover( await self.pause() async def _wait_for_screen(self, timeout: float = 30.0) -> bool: - """Wait for the current screen to have processed all pending events. + """Wait for the current screen and its children to have processed all pending events. Args: timeout: A timeout in seconds to wait. Returns: - `True` if all events were processed, or `False` if the wait timed out. + `True` if all events were processed. `False` if an exception occurred, + meaning that not all events could be processed. + + Raises: + WaitForScreenTimeout: If the screen and its children didn't finish processing within the timeout. """ children = [self.app, *self.app.screen.walk_children(with_self=True)] count = 0 @@ -160,10 +168,29 @@ def decrement_counter() -> None: count += 1 if count: - # Wait for the count to return to zero, or a timeout - try: - await asyncio.wait_for(count_zero_event.wait(), timeout=timeout) - except asyncio.TimeoutError: + # Wait for the count to return to zero, or a timeout, or an exception + wait_for = [ + asyncio.create_task(count_zero_event.wait()), + asyncio.create_task(self.app._exception_event.wait()), + ] + _, pending = await asyncio.wait( + wait_for, + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() + + timed_out = len(wait_for) == len(pending) + if timed_out: + raise WaitForScreenTimeout( + "Timed out while waiting for widgets to process pending messages." + ) + + # We've either timed out, encountered an exception, or we've finished + # decrementing all the counters (all events processed in children). + if count > 0: return False return True diff --git a/src/textual/widgets/_option_list.py b/src/textual/widgets/_option_list.py index 3b365fa7bf..65fd77470d 100644 --- a/src/textual/widgets/_option_list.py +++ b/src/textual/widgets/_option_list.py @@ -630,14 +630,7 @@ def clear_options(self) -> Self: self.highlighted = None self._mouse_hovering_over = None self.virtual_size = Size(self.scrollable_content_region.width, 0) - # TODO: See https://github.com/Textualize/textual/issues/2582 -- it - # should not be necessary to do this like this here; ideally here in - # clear_options it would be a forced refresh, and also in a - # `on_show` it would be the same (which, I think, would actually - # solve the problem we're seeing). But, until such a time as we get - # to the bottom of 2582... this seems to delay the refresh enough - # that things fall into place. - self._request_content_tracking_refresh() + self._refresh_content_tracking(force=True) return self def _set_option_disabled(self, index: int, disabled: bool) -> Self: diff --git a/src/textual/widgets/_tabs.py b/src/textual/widgets/_tabs.py index 3efab6d921..5e4e312bc7 100644 --- a/src/textual/widgets/_tabs.py +++ b/src/textual/widgets/_tabs.py @@ -422,9 +422,7 @@ def watch_active(self, previously_active: str, active: str) -> None: active_tab = self.query_one(f"#tabs-list > #{active}", Tab) self.query("#tabs-list > Tab.-active").remove_class("-active") active_tab.add_class("-active") - self.call_after_refresh( - self._highlight_active, animate=previously_active != "" - ) + self.call_later(self._highlight_active, animate=previously_active != "") self.post_message(self.TabActivated(self, active_tab)) else: underline = self.query_one(Underline) diff --git a/tests/test_pilot.py b/tests/test_pilot.py index 04491dea9a..d631146c77 100644 --- a/tests/test_pilot.py +++ b/tests/test_pilot.py @@ -1,7 +1,11 @@ from string import punctuation +import pytest + from textual import events -from textual.app import App +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.widgets import Label KEY_CHARACTERS_TO_TEST = "akTW03" + punctuation """Test some "simple" characters (letters + digits) and all punctuation.""" @@ -19,3 +23,32 @@ def on_key(self, event: events.Key) -> None: for char in KEY_CHARACTERS_TO_TEST: await pilot.press(char) assert keys_pressed[-1] == char + + +async def test_pilot_exception_catching_compose(): + """Ensuring that test frameworks are aware of exceptions + inside compose methods when running via Pilot run_test().""" + + class FailingApp(App): + def compose(self) -> ComposeResult: + 1 / 0 + yield Label("Beep") + + with pytest.raises(ZeroDivisionError): + async with FailingApp().run_test(): + pass + + +async def test_pilot_exception_catching_action(): + """Ensure that exceptions inside action handlers are presented + to the test framework when running via Pilot run_test().""" + + class FailingApp(App): + BINDINGS = [Binding("b", "beep", "beep")] + + def action_beep(self) -> None: + 1 / 0 + + with pytest.raises(ZeroDivisionError): + async with FailingApp().run_test() as pilot: + await pilot.press("b") diff --git a/tests/test_screen_modes.py b/tests/test_screen_modes.py index b7b34b133e..7ce402fc22 100644 --- a/tests/test_screen_modes.py +++ b/tests/test_screen_modes.py @@ -181,7 +181,7 @@ def compose(self) -> ComposeResult: yield Label("fast") def on_mount(self) -> None: - self.set_interval(0.01, self.ping) + self.call_later(self.set_interval, 0.01, self.ping) def ping(self) -> None: pings.append(str(self.app.query_one(Label).renderable))