diff --git a/src/textual/_on.py b/src/textual/_on.py index 30689a0522..32554689ac 100644 --- a/src/textual/_on.py +++ b/src/textual/_on.py @@ -13,22 +13,37 @@ class OnDecoratorError(Exception): """Errors related to the `on` decorator. Typically raised at import time as an early warning system. - """ def on( - message_type: type[Message], selector: str | None = None + message_type: type[Message], selector: str | None = None, **kwargs: str ) -> Callable[[DecoratedType], DecoratedType]: - """Decorator to declare method is a message handler. + """Decorator to declare that method is a message handler. + + The decorator can take a CSS selector that is applied to the attribute `control` + of the message. Example: ```python + # Handle the press of buttons with ID "#quit". @on(Button.Pressed, "#quit") def quit_button(self) -> None: self.app.quit() ``` + Additionally, arbitrary keyword arguments can be used to provide further selectors + for arbitrary attributes of the messages passed in. + + Example: + ```python + # Handle the activation of the tab "#home" within the `TabbedContent` "#tabs". + @on(TabbedContent.TabActivated, "#tabs", tab="#home") + def switch_to_home(self) -> None: + self.log("Switching back to the home tab.") + ... + ``` + Args: message_type: The message type (i.e. the class). selector: An optional [selector](/guide/CSS#selectors). If supplied, the handler will only be called if `selector` @@ -40,12 +55,18 @@ def quit_button(self) -> None: "The 'selector' argument requires a message class with a 'control' attribute (such as events from controls)." ) + selectors: dict[str, str] = {} if selector is not None: + selectors["control"] = selector + if kwargs: + selectors.update(kwargs) + + for attribute, css_selector in selectors.items(): try: - parse_selectors(selector) - except TokenError as error: + parse_selectors(css_selector) + except TokenError: raise OnDecoratorError( - f"Unable to parse selector {selector!r}; check for syntax errors" + f"Unable to parse selector {css_selector!r} for {attribute}; check for syntax errors" ) from None def decorator(method: DecoratedType) -> DecoratedType: @@ -53,7 +74,7 @@ def decorator(method: DecoratedType) -> DecoratedType: if not hasattr(method, "_textual_on"): setattr(method, "_textual_on", []) - getattr(method, "_textual_on").append((message_type, selector)) + getattr(method, "_textual_on").append((message_type, selectors)) return method diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index 18578606af..fadf6994a2 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -60,15 +60,15 @@ def __new__( namespace = camel_to_snake(name) isclass = inspect.isclass handlers: dict[ - type[Message], list[tuple[Callable, str | None]] + type[Message], list[tuple[Callable, dict[str, str]]] ] = class_dict.get("_decorated_handlers", {}) class_dict["_decorated_handlers"] = handlers for value in class_dict.values(): if callable(value) and hasattr(value, "_textual_on"): - for message_type, selector in getattr(value, "_textual_on"): - handlers.setdefault(message_type, []).append((value, selector)) + for message_type, selectors in getattr(value, "_textual_on"): + handlers.setdefault(message_type, []).append((value, selectors)) if isclass(value) and issubclass(value, Message): if "namespace" not in value.__dict__: value.namespace = namespace @@ -563,14 +563,25 @@ def _get_dispatch_methods( decorated_handlers = cls.__dict__.get("_decorated_handlers") if decorated_handlers is not None: handlers = decorated_handlers.get(type(message), []) - for method, selector in handlers: - if selector is None: + _sentinel = object() + for method, selectors in handlers: + if not selectors: yield cls, method.__get__(self, cls) else: - selector_sets = parse_selectors(selector) - if message._sender is not None and match( - selector_sets, message.control - ): + print("===") + print(message) + print(selectors) + if not message._sender: + continue + for attribute, selector in selectors.items(): + node = getattr(message, attribute, _sentinel) + print(f"Matching {node} against {selector}") + if node is _sentinel: + break + if not match(parse_selectors(selector), node): + break + print("passed") + else: yield cls, method.__get__(self, cls) # Fall back to the naming convention