diff --git a/CHANGELOG.md b/CHANGELOG.md index f852ea6d5a..64be93e2f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added `TreeNode.parent` -- a read-only property for accessing a node's parent https://github.com/Textualize/textual/issues/1397 - Added public `TreeNode` label access via `TreeNode.label` https://github.com/Textualize/textual/issues/1396 +- Added read-only public access to the children of a `TreeNode` via `TreeNode.children` https://github.com/Textualize/textual/issues/1398 ### Changed diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py index 39b63a0376..9fb5639eca 100644 --- a/src/textual/widgets/_tree.py +++ b/src/textual/widgets/_tree.py @@ -14,6 +14,7 @@ from .._segment_tools import line_crop, line_pad from .._types import MessageTarget from .._typing import TypeAlias +from .._collections import ImmutableSequence from ..binding import Binding from ..geometry import Region, Size, clamp from ..message import Message @@ -53,6 +54,10 @@ def _get_guide_width(self, guide_depth: int, show_root: bool) -> int: return guides +class TreeNodes(ImmutableSequence["TreeNode[TreeDataType]"]): + """An immutable collection of `TreeNode`.""" + + @rich.repr.auto class TreeNode(Generic[TreeDataType]): """An object that represents a "node" in a tree control.""" @@ -91,6 +96,11 @@ def _reset(self) -> None: self._selected_ = False self._updates += 1 + @property + def children(self) -> TreeNodes[TreeDataType]: + """TreeNodes[TreeDataType]: The child nodes of a TreeNode.""" + return TreeNodes(self._children) + @property def line(self) -> int: """int: Get the line number for this node, or -1 if it is not displayed.""" diff --git a/tests/tree/test_tree_node_children.py b/tests/tree/test_tree_node_children.py new file mode 100644 index 0000000000..22df664dba --- /dev/null +++ b/tests/tree/test_tree_node_children.py @@ -0,0 +1,32 @@ +import pytest +from textual.widgets import Tree, TreeNode + +def label_of(node: TreeNode[None]): + """Get the label of a node. + + TODO: This is just a helper function to reduce the number of type + errors, which can and will be remove once this code is merged with a + version of main that also has the TreeNode.label PR merged. + """ + return str(node._label) + + +def test_tree_node_children() -> None: + """A node's children property should act like an immutable list.""" + CHILDREN=23 + tree = Tree[None]("Root") + for child in range(CHILDREN): + tree.root.add(str(child)) + assert len(tree.root.children)==CHILDREN + for child in range(CHILDREN): + assert label_of(tree.root.children[child]) == str(child) + assert label_of(tree.root.children[0]) == "0" + assert label_of(tree.root.children[-1]) == str(CHILDREN-1) + assert [label_of(node) for node in tree.root.children] == [str(n) for n in range(CHILDREN)] + assert [label_of(node) for node in tree.root.children[:2]] == [str(n) for n in range(2)] + with pytest.raises(TypeError): + tree.root.children[0] = tree.root.children[1] + with pytest.raises(TypeError): + del tree.root.children[0] + with pytest.raises(TypeError): + del tree.root.children[0:2]