diff --git a/CHANGELOG.md b/CHANGELOG.md index 45d6204bc3..23739959d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fixed setting `TreeNode.label` on an existing `Tree` node not immediately https://github.com/Textualize/textual/pull/2713 - Correctly implement `__eq__` protocol in DataTable https://github.com/Textualize/textual/pull/2705 +### Changed + +- Breaking change: The `@on` decorator will now match a message class and any child classes https://github.com/Textualize/textual/pull/2746 + ## [0.27.0] - 2023-06-01 ### Fixed diff --git a/src/textual/events.py b/src/textual/events.py index 59fd42d5a6..4e2523535d 100644 --- a/src/textual/events.py +++ b/src/textual/events.py @@ -138,7 +138,7 @@ class Mount(Event, bubble=False, verbose=False): """ -class Unmount(Mount, bubble=False, verbose=False): +class Unmount(Event, bubble=False, verbose=False): """Sent when a widget is unmounted and may not longer receive messages. - [ ] Bubbles diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index 2c191d3de7..37251b9a32 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -10,7 +10,7 @@ from asyncio import CancelledError, Queue, QueueEmpty, Task from contextlib import contextmanager from functools import partial -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable, cast from weakref import WeakSet from . import Logger, events, log, messages @@ -65,7 +65,10 @@ def __new__( for value in class_dict.values(): if callable(value) and hasattr(value, "_textual_on"): - for message_type, selectors in getattr(value, "_textual_on"): + textual_on: list[ + tuple[type[Message], dict[str, tuple[SelectorSet, ...]]] + ] = getattr(value, "_textual_on") + for message_type, selectors in textual_on: handlers.setdefault(message_type, []).append((value, selectors)) if isclass(value) and issubclass(value, Message): if "namespace" in value.__dict__: @@ -591,31 +594,45 @@ def _get_dispatch_methods( method_name: Handler method name. message: Message object. """ + from .widget import Widget + + methods_dispatched: set[Callable] = set() + message_mro = [ + _type for _type in message.__class__.__mro__ if issubclass(_type, Message) + ] for cls in self.__class__.__mro__: if message._no_default_action: break # Try decorated handlers first - decorated_handlers = cls.__dict__.get("_decorated_handlers") - if decorated_handlers is not None: - handlers = decorated_handlers.get(type(message), []) - from .widget import Widget - - for method, selectors in handlers: - if not selectors: - yield cls, method.__get__(self, cls) - else: - if not message._sender: + decorated_handlers = cast( + "dict[type[Message], list[tuple[Callable, dict[str, tuple[SelectorSet, ...]]]]] | None", + cls.__dict__.get("_decorated_handlers"), + ) + + if decorated_handlers: + for message_class in message_mro: + handlers = decorated_handlers.get(message_class, []) + + for method, selectors in handlers: + if method in methods_dispatched: continue - for attribute, selector in selectors.items(): - node = getattr(message, attribute) - if not isinstance(node, Widget): - raise OnNoWidget( - f"on decorator can't match against {attribute!r} as it is not a widget." - ) - if not match(selector, node): - break - else: + if not selectors: yield cls, method.__get__(self, cls) + methods_dispatched.add(method) + else: + if not message._sender: + continue + for attribute, selector in selectors.items(): + node = getattr(message, attribute) + if not isinstance(node, Widget): + raise OnNoWidget( + f"on decorator can't match against {attribute!r} as it is not a widget." + ) + if not match(selector, node): + break + else: + yield cls, method.__get__(self, cls) + methods_dispatched.add(method) # Fall back to the naming convention # But avoid calling the handler if it was decorated diff --git a/tests/selection_list/test_selection_messages.py b/tests/selection_list/test_selection_messages.py index d90f04c6d3..d8717d8c84 100644 --- a/tests/selection_list/test_selection_messages.py +++ b/tests/selection_list/test_selection_messages.py @@ -10,7 +10,6 @@ from textual import on from textual.app import App, ComposeResult -from textual.messages import Message from textual.widgets import OptionList, SelectionList @@ -24,12 +23,15 @@ def __init__(self) -> None: def compose(self) -> ComposeResult: yield SelectionList[int](*[(str(n), n) for n in range(10)]) - @on(OptionList.OptionHighlighted) - @on(OptionList.OptionSelected) - @on(SelectionList.SelectionHighlighted) - @on(SelectionList.SelectionToggled) + @on(OptionList.OptionMessage) + @on(SelectionList.SelectionMessage) @on(SelectionList.SelectedChanged) - def _record(self, event: Message) -> None: + def _record( + self, + event: OptionList.OptionMessage + | SelectionList.SelectionMessage + | SelectionList.SelectedChanged, + ) -> None: assert event.control == self.query_one(SelectionList) self.messages.append( ( diff --git a/tests/test_on.py b/tests/test_on.py index 740af6a944..2bfaf476a6 100644 --- a/tests/test_on.py +++ b/tests/test_on.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from dataclasses import dataclass + import pytest from textual import on @@ -143,3 +147,209 @@ def two(self) -> None: await pilot.press("tab", "right", "right") assert log == ["one", "two"] + + +class MessageSender(Widget): + @dataclass + class Parent(Message): + sender: MessageSender + + @property + def control(self) -> MessageSender: + return self.sender + + class Child(Parent): + pass + + def post_parent(self) -> None: + self.post_message(self.Parent(self)) + + def post_child(self) -> None: + self.post_message(self.Child(self)) + + +async def test_fire_on_inherited_message() -> None: + """Handlers should fire when descendant messages are posted.""" + + posted: list[str] = [] + + class InheritTestApp(App[None]): + def compose(self) -> ComposeResult: + yield MessageSender() + + @on(MessageSender.Parent) + def catch_parent(self) -> None: + posted.append("parent") + + @on(MessageSender.Child) + def catch_child(self) -> None: + posted.append("child") + + def on_mount(self) -> None: + self.query_one(MessageSender).post_parent() + self.query_one(MessageSender).post_child() + + + async with InheritTestApp().run_test(): + pass + + assert posted == ["parent", "child", "parent"] + + +async def test_fire_inherited_on_single_handler() -> None: + """Test having parent/child messages on a single handler.""" + + posted: list[str] = [] + + class InheritTestApp(App[None]): + def compose(self) -> ComposeResult: + yield MessageSender() + + @on(MessageSender.Parent) + @on(MessageSender.Child) + def catch_either(self, event: MessageSender.Parent) -> None: + posted.append(f"either {event.__class__.__name__}") + + def on_mount(self) -> None: + self.query_one(MessageSender).post_parent() + self.query_one(MessageSender).post_child() + + async with InheritTestApp().run_test(): + pass + + assert posted == ["either Parent", "either Child"] + + +async def test_fire_inherited_on_single_handler_multi_selector() -> None: + """Test having parent/child messages on a single handler but with different selectors.""" + + posted: list[str] = [] + + class InheritTestApp(App[None]): + def compose(self) -> ComposeResult: + yield MessageSender(classes="a b") + + @on(MessageSender.Parent, ".y") + @on(MessageSender.Child, ".y") + @on(MessageSender.Parent, ".a.b") + @on(MessageSender.Child, ".a.b") + @on(MessageSender.Parent, ".a") + @on(MessageSender.Child, ".a") + @on(MessageSender.Parent, ".b") + @on(MessageSender.Child, ".b") + @on(MessageSender.Parent, ".x") + @on(MessageSender.Child, ".x") + def catch_either(self, event: MessageSender.Parent) -> None: + posted.append(f"either {event.__class__.__name__}") + + @on(MessageSender.Child, ".a, .x") + def catch_selector_list_one_miss(self, event: MessageSender.Parent) -> None: + posted.append(f"selector list one miss {event.__class__.__name__}") + + @on(MessageSender.Child, ".a, .b") + def catch_selector_list_two_hits(self, event: MessageSender.Parent) -> None: + posted.append(f"selector list two hits {event.__class__.__name__}") + + @on(MessageSender.Child, ".a.b") + def catch_selector_combined_hits(self, event: MessageSender.Parent) -> None: + posted.append(f"combined hits {event.__class__.__name__}") + + @on(MessageSender.Child, ".a.x") + def catch_selector_combined_miss(self, event: MessageSender.Parent) -> None: + posted.append(f"combined miss {event.__class__.__name__}") + + def on_mount(self) -> None: + self.query_one(MessageSender).post_parent() + self.query_one(MessageSender).post_child() + + async with InheritTestApp().run_test(): + pass + + assert posted == [ + "either Parent", + "either Child", + "selector list one miss Child", + "selector list two hits Child", + "combined hits Child", + ] + + +async def test_fire_inherited_and_on_methods() -> None: + posted: list[str] = [] + + class OnAndOnTestApp(App[None]): + def compose(self) -> ComposeResult: + yield MessageSender() + + def on_message_sender_parent(self) -> None: + posted.append("on_message_sender_parent") + + @on(MessageSender.Parent) + def catch_parent(self) -> None: + posted.append("catch_parent") + + def on_message_sender_child(self) -> None: + posted.append("on_message_sender_child") + + @on(MessageSender.Child) + def catch_child(self) -> None: + posted.append("catch_child") + + def on_mount(self) -> None: + self.query_one(MessageSender).post_parent() + self.query_one(MessageSender).post_child() + + async with OnAndOnTestApp().run_test(): + pass + + assert posted == [ + "catch_parent", + "on_message_sender_parent", + "catch_child", + "catch_parent", + "on_message_sender_child", + ] + + +class MixinMessageSender(Widget): + class Parent(Message): + pass + + class JustSomeRandomMixin: + pass + + class Child(JustSomeRandomMixin, Parent): + pass + + def post_parent(self) -> None: + self.post_message(self.Parent()) + + def post_child(self) -> None: + self.post_message(self.Child()) + + +async def test_fire_on_inherited_message_plus_mixins() -> None: + """Handlers should fire when descendant messages are posted, without mixins messing things up.""" + + posted: list[str] = [] + + class InheritTestApp(App[None]): + def compose(self) -> ComposeResult: + yield MixinMessageSender() + + @on(MixinMessageSender.Parent) + def catch_parent(self) -> None: + posted.append("parent") + + @on(MixinMessageSender.Child) + def catch_child(self) -> None: + posted.append("child") + + def on_mount(self) -> None: + self.query_one(MixinMessageSender).post_parent() + self.query_one(MixinMessageSender).post_child() + + async with InheritTestApp().run_test(): + pass + + assert posted == ["parent", "child", "parent"]