Skip to content

Commit

Permalink
Merge pull request #902 from Textualize/depth-first
Browse files Browse the repository at this point in the history
depth first search
  • Loading branch information
willmcgugan authored Oct 14, 2022
2 parents 88447b7 + 14316fd commit ef7ecc0
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 31 deletions.
9 changes: 9 additions & 0 deletions sandbox/will/screens_focus.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Focusable {
padding: 3 6;
background: blue 20%;
}

Focusable :focus {
border: solid red;
}

20 changes: 20 additions & 0 deletions sandbox/will/screens_focus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from textual.app import App, ComposeResult
from textual.widgets import Static, Footer


class Focusable(Static, can_focus=True):
pass


class ScreensFocusApp(App):
def compose(self) -> ComposeResult:
yield Focusable("App - one")
yield Focusable("App - two")
yield Focusable("App - three")
yield Focusable("App - four")
yield Footer()


app = ScreensFocusApp(css_path="screens_focus.css")
if __name__ == "__main__":
app.run()
11 changes: 8 additions & 3 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .design import ColorSystem
from .devtools.client import DevtoolsClient, DevtoolsConnectionError, DevtoolsLog
from .devtools.redirect_output import StdoutRedirector
from .dom import DOMNode
from .dom import DOMNode, NoScreen
from .driver import Driver
from .drivers.headless_driver import HeadlessDriver
from .features import FeatureFlag, parse_features
Expand Down Expand Up @@ -1142,7 +1142,10 @@ def _unregister(self, widget: Widget) -> None:
Args:
widget (Widget): A Widget to unregister
"""
widget.screen._reset_focus(widget)
try:
widget.screen._reset_focus(widget)
except NoScreen:
pass

if isinstance(widget._parent, Widget):
widget._parent.children._remove(widget)
Expand Down Expand Up @@ -1394,7 +1397,9 @@ async def _on_remove(self, event: events.Remove) -> None:
if parent is not None:
parent.refresh(layout=True)

remove_widgets = list(widget.walk_children(Widget, with_self=True))
remove_widgets = widget.walk_children(
Widget, with_self=True, method="depth", reverse=True
)
for child in remove_widgets:
self._unregister(child)
for child in remove_widgets:
Expand Down
2 changes: 1 addition & 1 deletion src/textual/css/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def refresh(self, *, layout: bool = False, children: bool = False) -> None:
if self.node is not None:
self.node.refresh(layout=layout)
if children:
for child in self.node.walk_children(with_self=False):
for child in self.node.walk_children(with_self=False, reverse=True):
child.refresh(layout=layout)

def reset(self) -> None:
Expand Down
102 changes: 79 additions & 23 deletions src/textual/dom.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

from inspect import getfile
import re
import sys
from collections import deque
from inspect import getfile
from typing import (
cast,
TYPE_CHECKING,
ClassVar,
Iterable,
Iterator,
Type,
overload,
TypeVar,
TYPE_CHECKING,
cast,
overload,
)

import rich.repr
Expand All @@ -23,14 +25,14 @@
from ._context import NoActiveAppError
from ._node_list import NodeList
from .binding import Bindings, BindingType
from .color import Color, WHITE, BLACK
from .color import BLACK, WHITE, Color
from .css._error_tools import friendly_list
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
from .css.errors import StyleValueError, DeclarationError
from .css.errors import DeclarationError, StyleValueError
from .css.parse import parse_declarations
from .css.styles import Styles, RenderStyles
from .css.tokenize import IDENTIFIER
from .css.query import NoMatches
from .css.styles import RenderStyles, Styles
from .css.tokenize import IDENTIFIER
from .message_pump import MessagePump
from .timer import Timer

Expand All @@ -40,10 +42,23 @@
from .screen import Screen
from .widget import Widget

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

if sys.version_info >= (3, 10):
from typing import TypeAlias
else: # pragma: no cover
from typing_extensions import TypeAlias


_re_identifier = re.compile(IDENTIFIER)


WalkMethod: TypeAlias = Literal["depth", "breadth"]


class BadIdentifier(Exception):
"""raised by check_identifiers."""

Expand Down Expand Up @@ -617,48 +632,89 @@ def walk_children(
filter_type: type[WalkType],
*,
with_self: bool = True,
method: WalkMethod = "depth",
reverse: bool = False,
) -> Iterable[WalkType]:
...

@overload
def walk_children(self, *, with_self: bool = True) -> Iterable[DOMNode]:
def walk_children(
self,
*,
with_self: bool = True,
method: WalkMethod = "depth",
reverse: bool = False,
) -> Iterable[DOMNode]:
...

def walk_children(
self,
filter_type: type[WalkType] | None = None,
*,
with_self: bool = True,
method: WalkMethod = "depth",
reverse: bool = False,
) -> Iterable[DOMNode | WalkType]:
"""Generate descendant nodes.
Args:
filter_type (type[WalkType] | None, optional): Filter only this type, or None for no filter.
Defaults to None.
with_self (bool, optional): Also yield self in addition to descendants. Defaults to True.
method (Literal["breadth", "depth"], optional): One of "depth" or "breadth". Defaults to "depth".
reverse (bool, optional): Reverse the order (bottom up). Defaults to False
Returns:
Iterable[DOMNode | WalkType]: An iterable of nodes.
"""

stack: list[Iterator[DOMNode]] = [iter(self.children)]
pop = stack.pop
push = stack.append
check_type = filter_type or DOMNode

if with_self and isinstance(self, check_type):
yield self
def walk_depth_first() -> Iterable[DOMNode]:
"""Walk the tree depth first (parents first)."""
stack: list[Iterator[DOMNode]] = [iter(self.children)]
pop = stack.pop
push = stack.append
check_type = filter_type or DOMNode

while stack:
node = next(stack[-1], None)
if node is None:
pop()
else:
if with_self and isinstance(self, check_type):
yield self
while stack:
node = next(stack[-1], None)
if node is None:
pop()
else:
if isinstance(node, check_type):
yield node
if node.children:
push(iter(node.children))

def walk_breadth_first() -> Iterable[DOMNode]:
"""Walk the tree breadth first (children first)."""
queue: deque[DOMNode] = deque()
popleft = queue.popleft
extend = queue.extend
check_type = filter_type or DOMNode

if with_self and isinstance(self, check_type):
yield self
extend(self.children)
while queue:
node = popleft()
if isinstance(node, check_type):
yield node
if node.children:
push(iter(node.children))
extend(node.children)

node_generator = (
walk_depth_first() if method == "depth" else walk_breadth_first()
)

# We want a snapshot of the DOM at this point
# So that is doesn't change mid-walk
nodes = list(node_generator)
if reverse:
yield from reversed(nodes)
else:
yield from nodes

def get_child(self, id: str) -> DOMNode:
"""Return the first child (immediate descendent) of this node with the given ID.
Expand Down
4 changes: 0 additions & 4 deletions src/textual/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,10 +1857,6 @@ async def handle_key(self, event: events.Key) -> bool:
await self.action(binding.action)
return True

def _on_compose(self, event: events.Compose) -> None:
widgets = self.compose()
self.app.mount_all(widgets)

def _on_mount(self, event: events.Mount) -> None:
widgets = self.compose()
self.mount(*widgets)
Expand Down
62 changes: 62 additions & 0 deletions tests/test_dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,65 @@ def test_validate():
node.remove_class("1")
with pytest.raises(BadIdentifier):
node.toggle_class("1")


@pytest.fixture
def search():
"""
a
/ \
b c
/ / \
d e f
"""
a = DOMNode(id="a")
b = DOMNode(id="b")
c = DOMNode(id="c")
d = DOMNode(id="d")
e = DOMNode(id="e")
f = DOMNode(id="f")

a._add_child(b)
a._add_child(c)
b._add_child(d)
c._add_child(e)
c._add_child(f)

yield a


def test_walk_children_depth(search):
children = [
node.id for node in search.walk_children(method="depth", with_self=False)
]
assert children == ["b", "d", "c", "e", "f"]


def test_walk_children_with_self_depth(search):
children = [
node.id for node in search.walk_children(method="depth", with_self=True)
]
assert children == ["a", "b", "d", "c", "e", "f"]


def test_walk_children_breadth(search):
children = [
node.id for node in search.walk_children(with_self=False, method="breadth")
]
print(children)
assert children == ["b", "c", "d", "e", "f"]


def test_walk_children_with_self_breadth(search):
children = [
node.id for node in search.walk_children(with_self=True, method="breadth")
]
print(children)
assert children == ["a", "b", "c", "d", "e", "f"]

children = [
node.id
for node in search.walk_children(with_self=True, method="breadth", reverse=True)
]

assert children == ["f", "e", "d", "c", "b", "a"]

0 comments on commit ef7ecc0

Please sign in to comment.