Skip to content

Commit

Permalink
on super class (Textualize#2746)
Browse files Browse the repository at this point in the history
* on super class

* simplification

* simplify

* remove whitespace

* changelog

* changelog

* Update tests/test_on.py

Co-authored-by: darrenburns <[email protected]>

---------

Co-authored-by: darrenburns <[email protected]>
  • Loading branch information
willmcgugan and darrenburns authored Jun 7, 2023
1 parent c5253f4 commit 8947dbe
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 28 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Fixed setting `TreeNode.label` on an existing `Tree` node not immediately https://github.com/Textualize/textual/pull/2713
- Correctly implement `__eq__` protocol in DataTable https://github.com/Textualize/textual/pull/2705

### Changed

- Breaking change: The `@on` decorator will now match a message class and any child classes https://github.com/Textualize/textual/pull/2746

## [0.27.0] - 2023-06-01

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion src/textual/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class Mount(Event, bubble=False, verbose=False):
"""


class Unmount(Mount, bubble=False, verbose=False):
class Unmount(Event, bubble=False, verbose=False):
"""Sent when a widget is unmounted and may not longer receive messages.
- [ ] Bubbles
Expand Down
59 changes: 38 additions & 21 deletions src/textual/message_pump.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from asyncio import CancelledError, Queue, QueueEmpty, Task
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable, cast
from weakref import WeakSet

from . import Logger, events, log, messages
Expand Down Expand Up @@ -65,7 +65,10 @@ def __new__(

for value in class_dict.values():
if callable(value) and hasattr(value, "_textual_on"):
for message_type, selectors in getattr(value, "_textual_on"):
textual_on: list[
tuple[type[Message], dict[str, tuple[SelectorSet, ...]]]
] = getattr(value, "_textual_on")
for message_type, selectors in textual_on:
handlers.setdefault(message_type, []).append((value, selectors))
if isclass(value) and issubclass(value, Message):
if "namespace" in value.__dict__:
Expand Down Expand Up @@ -591,31 +594,45 @@ def _get_dispatch_methods(
method_name: Handler method name.
message: Message object.
"""
from .widget import Widget

methods_dispatched: set[Callable] = set()
message_mro = [
_type for _type in message.__class__.__mro__ if issubclass(_type, Message)
]
for cls in self.__class__.__mro__:
if message._no_default_action:
break
# Try decorated handlers first
decorated_handlers = cls.__dict__.get("_decorated_handlers")
if decorated_handlers is not None:
handlers = decorated_handlers.get(type(message), [])
from .widget import Widget

for method, selectors in handlers:
if not selectors:
yield cls, method.__get__(self, cls)
else:
if not message._sender:
decorated_handlers = cast(
"dict[type[Message], list[tuple[Callable, dict[str, tuple[SelectorSet, ...]]]]] | None",
cls.__dict__.get("_decorated_handlers"),
)

if decorated_handlers:
for message_class in message_mro:
handlers = decorated_handlers.get(message_class, [])

for method, selectors in handlers:
if method in methods_dispatched:
continue
for attribute, selector in selectors.items():
node = getattr(message, attribute)
if not isinstance(node, Widget):
raise OnNoWidget(
f"on decorator can't match against {attribute!r} as it is not a widget."
)
if not match(selector, node):
break
else:
if not selectors:
yield cls, method.__get__(self, cls)
methods_dispatched.add(method)
else:
if not message._sender:
continue
for attribute, selector in selectors.items():
node = getattr(message, attribute)
if not isinstance(node, Widget):
raise OnNoWidget(
f"on decorator can't match against {attribute!r} as it is not a widget."
)
if not match(selector, node):
break
else:
yield cls, method.__get__(self, cls)
methods_dispatched.add(method)

# Fall back to the naming convention
# But avoid calling the handler if it was decorated
Expand Down
14 changes: 8 additions & 6 deletions tests/selection_list/test_selection_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from textual import on
from textual.app import App, ComposeResult
from textual.messages import Message
from textual.widgets import OptionList, SelectionList


Expand All @@ -24,12 +23,15 @@ def __init__(self) -> None:
def compose(self) -> ComposeResult:
yield SelectionList[int](*[(str(n), n) for n in range(10)])

@on(OptionList.OptionHighlighted)
@on(OptionList.OptionSelected)
@on(SelectionList.SelectionHighlighted)
@on(SelectionList.SelectionToggled)
@on(OptionList.OptionMessage)
@on(SelectionList.SelectionMessage)
@on(SelectionList.SelectedChanged)
def _record(self, event: Message) -> None:
def _record(
self,
event: OptionList.OptionMessage
| SelectionList.SelectionMessage
| SelectionList.SelectedChanged,
) -> None:
assert event.control == self.query_one(SelectionList)
self.messages.append(
(
Expand Down
210 changes: 210 additions & 0 deletions tests/test_on.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass

import pytest

from textual import on
Expand Down Expand Up @@ -143,3 +147,209 @@ def two(self) -> None:
await pilot.press("tab", "right", "right")

assert log == ["one", "two"]


class MessageSender(Widget):
@dataclass
class Parent(Message):
sender: MessageSender

@property
def control(self) -> MessageSender:
return self.sender

class Child(Parent):
pass

def post_parent(self) -> None:
self.post_message(self.Parent(self))

def post_child(self) -> None:
self.post_message(self.Child(self))


async def test_fire_on_inherited_message() -> None:
"""Handlers should fire when descendant messages are posted."""

posted: list[str] = []

class InheritTestApp(App[None]):
def compose(self) -> ComposeResult:
yield MessageSender()

@on(MessageSender.Parent)
def catch_parent(self) -> None:
posted.append("parent")

@on(MessageSender.Child)
def catch_child(self) -> None:
posted.append("child")

def on_mount(self) -> None:
self.query_one(MessageSender).post_parent()
self.query_one(MessageSender).post_child()


async with InheritTestApp().run_test():
pass

assert posted == ["parent", "child", "parent"]


async def test_fire_inherited_on_single_handler() -> None:
"""Test having parent/child messages on a single handler."""

posted: list[str] = []

class InheritTestApp(App[None]):
def compose(self) -> ComposeResult:
yield MessageSender()

@on(MessageSender.Parent)
@on(MessageSender.Child)
def catch_either(self, event: MessageSender.Parent) -> None:
posted.append(f"either {event.__class__.__name__}")

def on_mount(self) -> None:
self.query_one(MessageSender).post_parent()
self.query_one(MessageSender).post_child()

async with InheritTestApp().run_test():
pass

assert posted == ["either Parent", "either Child"]


async def test_fire_inherited_on_single_handler_multi_selector() -> None:
"""Test having parent/child messages on a single handler but with different selectors."""

posted: list[str] = []

class InheritTestApp(App[None]):
def compose(self) -> ComposeResult:
yield MessageSender(classes="a b")

@on(MessageSender.Parent, ".y")
@on(MessageSender.Child, ".y")
@on(MessageSender.Parent, ".a.b")
@on(MessageSender.Child, ".a.b")
@on(MessageSender.Parent, ".a")
@on(MessageSender.Child, ".a")
@on(MessageSender.Parent, ".b")
@on(MessageSender.Child, ".b")
@on(MessageSender.Parent, ".x")
@on(MessageSender.Child, ".x")
def catch_either(self, event: MessageSender.Parent) -> None:
posted.append(f"either {event.__class__.__name__}")

@on(MessageSender.Child, ".a, .x")
def catch_selector_list_one_miss(self, event: MessageSender.Parent) -> None:
posted.append(f"selector list one miss {event.__class__.__name__}")

@on(MessageSender.Child, ".a, .b")
def catch_selector_list_two_hits(self, event: MessageSender.Parent) -> None:
posted.append(f"selector list two hits {event.__class__.__name__}")

@on(MessageSender.Child, ".a.b")
def catch_selector_combined_hits(self, event: MessageSender.Parent) -> None:
posted.append(f"combined hits {event.__class__.__name__}")

@on(MessageSender.Child, ".a.x")
def catch_selector_combined_miss(self, event: MessageSender.Parent) -> None:
posted.append(f"combined miss {event.__class__.__name__}")

def on_mount(self) -> None:
self.query_one(MessageSender).post_parent()
self.query_one(MessageSender).post_child()

async with InheritTestApp().run_test():
pass

assert posted == [
"either Parent",
"either Child",
"selector list one miss Child",
"selector list two hits Child",
"combined hits Child",
]


async def test_fire_inherited_and_on_methods() -> None:
posted: list[str] = []

class OnAndOnTestApp(App[None]):
def compose(self) -> ComposeResult:
yield MessageSender()

def on_message_sender_parent(self) -> None:
posted.append("on_message_sender_parent")

@on(MessageSender.Parent)
def catch_parent(self) -> None:
posted.append("catch_parent")

def on_message_sender_child(self) -> None:
posted.append("on_message_sender_child")

@on(MessageSender.Child)
def catch_child(self) -> None:
posted.append("catch_child")

def on_mount(self) -> None:
self.query_one(MessageSender).post_parent()
self.query_one(MessageSender).post_child()

async with OnAndOnTestApp().run_test():
pass

assert posted == [
"catch_parent",
"on_message_sender_parent",
"catch_child",
"catch_parent",
"on_message_sender_child",
]


class MixinMessageSender(Widget):
class Parent(Message):
pass

class JustSomeRandomMixin:
pass

class Child(JustSomeRandomMixin, Parent):
pass

def post_parent(self) -> None:
self.post_message(self.Parent())

def post_child(self) -> None:
self.post_message(self.Child())


async def test_fire_on_inherited_message_plus_mixins() -> None:
"""Handlers should fire when descendant messages are posted, without mixins messing things up."""

posted: list[str] = []

class InheritTestApp(App[None]):
def compose(self) -> ComposeResult:
yield MixinMessageSender()

@on(MixinMessageSender.Parent)
def catch_parent(self) -> None:
posted.append("parent")

@on(MixinMessageSender.Child)
def catch_child(self) -> None:
posted.append("child")

def on_mount(self) -> None:
self.query_one(MixinMessageSender).post_parent()
self.query_one(MixinMessageSender).post_child()

async with InheritTestApp().run_test():
pass

assert posted == ["parent", "child", "parent"]

0 comments on commit 8947dbe

Please sign in to comment.