diff --git a/sandbox/will/screens_focus.css b/sandbox/will/screens_focus.css new file mode 100644 index 0000000000..dd2a3ab260 --- /dev/null +++ b/sandbox/will/screens_focus.css @@ -0,0 +1,9 @@ + Focusable { + padding: 3 6; + background: blue 20%; + } + + Focusable :focus { + border: solid red; + } + diff --git a/sandbox/will/screens_focus.py b/sandbox/will/screens_focus.py new file mode 100644 index 0000000000..2d35f54705 --- /dev/null +++ b/sandbox/will/screens_focus.py @@ -0,0 +1,20 @@ +from textual.app import App, ComposeResult +from textual.widgets import Static, Footer + + +class Focusable(Static, can_focus=True): + pass + + +class ScreensFocusApp(App): + def compose(self) -> ComposeResult: + yield Focusable("App - one") + yield Focusable("App - two") + yield Focusable("App - three") + yield Focusable("App - four") + yield Footer() + + +app = ScreensFocusApp(css_path="screens_focus.css") +if __name__ == "__main__": + app.run() diff --git a/src/textual/app.py b/src/textual/app.py index d286dcffe5..7cf010f3c1 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -37,7 +37,7 @@ from .design import ColorSystem from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog from .devtools.redirect_output import StdoutRedirector -from .dom import DOMNode +from .dom import DOMNode, NoScreen from .driver import Driver from .drivers.headless_driver import HeadlessDriver from .features import FeatureFlag, parse_features @@ -1142,7 +1142,10 @@ def _unregister(self, widget: Widget) -> None: Args: widget (Widget): A Widget to unregister """ - widget.screen._reset_focus(widget) + try: + widget.screen._reset_focus(widget) + except NoScreen: + pass if isinstance(widget._parent, Widget): widget._parent.children._remove(widget) @@ -1394,7 +1397,9 @@ async def _on_remove(self, event: events.Remove) -> None: if parent is not None: parent.refresh(layout=True) - remove_widgets = list(widget.walk_children(Widget, with_self=True)) + remove_widgets = widget.walk_children( + Widget, with_self=True, method="depth", reverse=True + ) for child in remove_widgets: self._unregister(child) for child in remove_widgets: diff --git a/src/textual/css/styles.py b/src/textual/css/styles.py index d59bea2f66..dd4b8d4b63 100644 --- a/src/textual/css/styles.py +++ b/src/textual/css/styles.py @@ -573,7 +573,7 @@ def refresh(self, *, layout: bool = False, children: bool = False) -> None: if self.node is not None: self.node.refresh(layout=layout) if children: - for child in self.node.walk_children(with_self=False): + for child in self.node.walk_children(with_self=False, reverse=True): child.refresh(layout=layout) def reset(self) -> None: diff --git a/src/textual/dom.py b/src/textual/dom.py index f3234ac866..cc3f0866e2 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -1,16 +1,18 @@ from __future__ import annotations -from inspect import getfile import re +import sys +from collections import deque +from inspect import getfile from typing import ( - cast, + TYPE_CHECKING, ClassVar, Iterable, Iterator, Type, - overload, TypeVar, - TYPE_CHECKING, + cast, + overload, ) import rich.repr @@ -23,14 +25,14 @@ from ._context import NoActiveAppError from ._node_list import NodeList from .binding import Bindings, BindingType -from .color import Color, WHITE, BLACK +from .color import BLACK, WHITE, Color from .css._error_tools import friendly_list from .css.constants import VALID_DISPLAY, VALID_VISIBILITY -from .css.errors import StyleValueError, DeclarationError +from .css.errors import DeclarationError, StyleValueError from .css.parse import parse_declarations -from .css.styles import Styles, RenderStyles -from .css.tokenize import IDENTIFIER from .css.query import NoMatches +from .css.styles import RenderStyles, Styles +from .css.tokenize import IDENTIFIER from .message_pump import MessagePump from .timer import Timer @@ -40,10 +42,23 @@ from .screen import Screen from .widget import Widget +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: # pragma: no cover + from typing_extensions import TypeAlias + _re_identifier = re.compile(IDENTIFIER) +WalkMethod: TypeAlias = Literal["depth", "breadth"] + + class BadIdentifier(Exception): """raised by check_identifiers.""" @@ -617,11 +632,19 @@ def walk_children( filter_type: type[WalkType], *, with_self: bool = True, + method: WalkMethod = "depth", + reverse: bool = False, ) -> Iterable[WalkType]: ... @overload - def walk_children(self, *, with_self: bool = True) -> Iterable[DOMNode]: + def walk_children( + self, + *, + with_self: bool = True, + method: WalkMethod = "depth", + reverse: bool = False, + ) -> Iterable[DOMNode]: ... def walk_children( @@ -629,6 +652,8 @@ def walk_children( filter_type: type[WalkType] | None = None, *, with_self: bool = True, + method: WalkMethod = "depth", + reverse: bool = False, ) -> Iterable[DOMNode | WalkType]: """Generate descendant nodes. @@ -636,29 +661,60 @@ def walk_children( filter_type (type[WalkType] | None, optional): Filter only this type, or None for no filter. Defaults to None. with_self (bool, optional): Also yield self in addition to descendants. Defaults to True. + method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "depth". + reverse (bool, optional): Reverse the order (bottom up). Defaults to False Returns: Iterable[DOMNode | WalkType]: An iterable of nodes. """ - stack: list[Iterator[DOMNode]] = [iter(self.children)] - pop = stack.pop - push = stack.append - check_type = filter_type or DOMNode - - if with_self and isinstance(self, check_type): - yield self + def walk_depth_first() -> Iterable[DOMNode]: + """Walk the tree depth first (parents first).""" + stack: list[Iterator[DOMNode]] = [iter(self.children)] + pop = stack.pop + push = stack.append + check_type = filter_type or DOMNode - while stack: - node = next(stack[-1], None) - if node is None: - pop() - else: + if with_self and isinstance(self, check_type): + yield self + while stack: + node = next(stack[-1], None) + if node is None: + pop() + else: + if isinstance(node, check_type): + yield node + if node.children: + push(iter(node.children)) + + def walk_breadth_first() -> Iterable[DOMNode]: + """Walk the tree breadth first (children first).""" + queue: deque[DOMNode] = deque() + popleft = queue.popleft + extend = queue.extend + check_type = filter_type or DOMNode + + if with_self and isinstance(self, check_type): + yield self + extend(self.children) + while queue: + node = popleft() if isinstance(node, check_type): yield node - if node.children: - push(iter(node.children)) + extend(node.children) + + node_generator = ( + walk_depth_first() if method == "depth" else walk_breadth_first() + ) + + # We want a snapshot of the DOM at this point + # So that is doesn't change mid-walk + nodes = list(node_generator) + if reverse: + yield from reversed(nodes) + else: + yield from nodes def get_child(self, id: str) -> DOMNode: """Return the first child (immediate descendent) of this node with the given ID. diff --git a/src/textual/widget.py b/src/textual/widget.py index 6f190ec09a..23858b64e5 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -1857,10 +1857,6 @@ async def handle_key(self, event: events.Key) -> bool: await self.action(binding.action) return True - def _on_compose(self, event: events.Compose) -> None: - widgets = self.compose() - self.app.mount_all(widgets) - def _on_mount(self, event: events.Mount) -> None: widgets = self.compose() self.mount(*widgets) diff --git a/tests/test_dom.py b/tests/test_dom.py index e4254f6e55..5a713193a7 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -75,3 +75,65 @@ def test_validate(): node.remove_class("1") with pytest.raises(BadIdentifier): node.toggle_class("1") + + +@pytest.fixture +def search(): + """ + a + / \ + b c + / / \ + d e f + """ + a = DOMNode(id="a") + b = DOMNode(id="b") + c = DOMNode(id="c") + d = DOMNode(id="d") + e = DOMNode(id="e") + f = DOMNode(id="f") + + a._add_child(b) + a._add_child(c) + b._add_child(d) + c._add_child(e) + c._add_child(f) + + yield a + + +def test_walk_children_depth(search): + children = [ + node.id for node in search.walk_children(method="depth", with_self=False) + ] + assert children == ["b", "d", "c", "e", "f"] + + +def test_walk_children_with_self_depth(search): + children = [ + node.id for node in search.walk_children(method="depth", with_self=True) + ] + assert children == ["a", "b", "d", "c", "e", "f"] + + +def test_walk_children_breadth(search): + children = [ + node.id for node in search.walk_children(with_self=False, method="breadth") + ] + print(children) + assert children == ["b", "c", "d", "e", "f"] + + +def test_walk_children_with_self_breadth(search): + children = [ + node.id for node in search.walk_children(with_self=True, method="breadth") + ] + print(children) + assert children == ["a", "b", "c", "d", "e", "f"] + + children = [ + node.id + for node in search.walk_children(with_self=True, method="breadth", reverse=True) + ] + + assert children == ["f", "e", "d", "c", "b", "a"]