diff --git a/CHANGELOG.md b/CHANGELOG.md index 4849a09bf5..76b63dc91f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed - Widget scrolling methods (such as `Widget.scroll_home` and `Widget.scroll_end`) now perform the scroll after the next refresh https://github.com/Textualize/textual/issues/1774 +- Buttons no longer accept arbitrary renderables https://github.com/Textualize/textual/issues/1870 ### Fixed diff --git a/mypy.ini b/mypy.ini index c61bcf0ace..37236b0c3d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,3 +14,7 @@ ignore_missing_imports = True [mypy-ipywidgets.*] ignore_missing_imports = True + +[mypy-uvloop.*] +# Ignore missing imports for optional library that isn't listed as a dependency. +ignore_missing_imports = True diff --git a/src/textual/__init__.py b/src/textual/__init__.py index 0ffb3533c1..bef888ae18 100644 --- a/src/textual/__init__.py +++ b/src/textual/__init__.py @@ -64,7 +64,10 @@ def __call__(self, *args: object, **kwargs) -> None: if app.devtools is None or not app.devtools.is_connected: return - previous_frame = inspect.currentframe().f_back + current_frame = inspect.currentframe() + assert current_frame is not None + previous_frame = current_frame.f_back + assert previous_frame is not None caller = inspect.getframeinfo(previous_frame) _log = self._log or app._log diff --git a/src/textual/_callback.py b/src/textual/_callback.py index ee38014540..add75588be 100644 --- a/src/textual/_callback.py +++ b/src/textual/_callback.py @@ -58,6 +58,7 @@ async def invoke(callback: Callable, *params: object) -> Any: # In debug mode we will warn about callbacks that may be stuck def log_slow() -> None: """Log a message regarding a slow callback.""" + assert app is not None app.log.warning( f"Callback {callback} is still pending after {INVOKE_TIMEOUT_WARNING} seconds" ) diff --git a/src/textual/_compositor.py b/src/textual/_compositor.py index 1ee1464fb8..79080f59f5 100644 --- a/src/textual/_compositor.py +++ b/src/textual/_compositor.py @@ -14,7 +14,7 @@ from __future__ import annotations from operator import itemgetter -from typing import TYPE_CHECKING, Iterable, NamedTuple, cast +from typing import TYPE_CHECKING, Callable, Iterable, NamedTuple, cast import rich.repr from rich.console import Console, ConsoleOptions, RenderableType, RenderResult @@ -45,12 +45,23 @@ class ReflowResult(NamedTuple): class MapGeometry(NamedTuple): """Defines the absolute location of a Widget.""" - region: Region # The (screen) region occupied by the widget - order: tuple[tuple[int, ...], ...] # A tuple of ints defining the painting order - clip: Region # A region to clip the widget by (if a Widget is within a container) - virtual_size: Size # The virtual size (scrollable region) of a widget if it is a container - container_size: Size # The container size (area not occupied by scrollbars) - virtual_region: Region # The region relative to the container (but not necessarily visible) + region: Region + """The (screen) region occupied by the widget.""" + order: tuple[tuple[int, int, int], ...] + """Tuple of tuples defining the painting order of the widget. + + Each successive triple represents painting order information with regards to + ancestors in the DOM hierarchy and the last triple provides painting order + information for this specific widget. + """ + clip: Region + """A region to clip the widget by (if a Widget is within a container).""" + virtual_size: Size + """The virtual size (scrollable region) of a widget if it is a container.""" + container_size: Size + """The container size (area not occupied by scrollbars).""" + virtual_region: Region + """The region relative to the container (but not necessarily visible).""" @property def visible_region(self) -> Region: @@ -419,19 +430,23 @@ def add_widget( widget: Widget, virtual_region: Region, region: Region, - order: tuple[tuple[int, ...], ...], + order: tuple[tuple[int, int, int], ...], layer_order: int, clip: Region, visible: bool, - _MapGeometry=MapGeometry, + _MapGeometry: type[MapGeometry] = MapGeometry, ) -> None: """Called recursively to place a widget and its children in the map. Args: widget: The widget to add. + virtual_region: The Widget region relative to it's container. region: The region the widget will occupy. - order: A tuple of ints to define the order. + order: Painting order information. + layer_order: The order of the widget in its layer. clip: The clipping region (i.e. the viewport which contains it). + visible: Whether the widget should be visible by default. + This may be overriden by the CSS rule `visibility`. """ visibility = widget.styles.get_rule("visibility") if visibility is not None: @@ -501,11 +516,12 @@ def add_widget( ) widget_region = sub_region + placement_scroll_offset - widget_order = ( - *order, - get_layer_index(sub_widget.layer, 0), - z, - layer_order, + widget_order = order + ( + ( + get_layer_index(sub_widget.layer, 0), + z, + layer_order, + ), ) add_widget( @@ -560,7 +576,7 @@ def add_widget( root, size.region, size.region, - ((0,),), + ((0, 0, 0),), layer_order, size.region, True, @@ -818,11 +834,8 @@ def render(self, full: bool = False) -> RenderableType | None: # Maps each cut on to a list of segments cuts = self.cuts - # dict.fromkeys is a callable which takes a list of ints returns a dict which maps ints on to a list of Segments or None. - fromkeys = cast( - "Callable[[list[int]], dict[int, list[Segment] | None]]", dict.fromkeys - ) - # A mapping of cut index to a list of segments for each line + # dict.fromkeys is a callable which takes a list of ints returns a dict which maps ints onto a Segment or None. + fromkeys = cast("Callable[[list[int]], dict[int, Strip | None]]", dict.fromkeys) chops: list[dict[int, Strip | None]] chops = [fromkeys(cut_set[:-1]) for cut_set in cuts] diff --git a/src/textual/_doc.py b/src/textual/_doc.py index 3b4c9d08b5..391c85d128 100644 --- a/src/textual/_doc.py +++ b/src/textual/_doc.py @@ -4,7 +4,7 @@ import os import shlex from pathlib import Path -from typing import Iterable +from typing import Iterable, cast from textual._import_app import import_app from textual.app import App @@ -45,6 +45,7 @@ def format_svg(source, language, css_class, options, md, attrs, **kwargs) -> str import traceback traceback.print_exception(error) + return "" def take_svg_screenshot( @@ -82,6 +83,7 @@ def get_cache_key(app: App) -> str: hash = hashlib.md5() file_paths = [app_path] + app.css_path for path in file_paths: + assert path is not None with open(path, "rb") as source_file: hash.update(source_file.read()) hash.update(f"{press}-{title}-{terminal_size}".encode("utf-8")) @@ -105,10 +107,13 @@ async def auto_pilot(pilot: Pilot) -> None: app.exit(svg) - svg = app.run( - headless=True, - auto_pilot=auto_pilot, - size=terminal_size, + svg = cast( + str, + app.run( + headless=True, + auto_pilot=auto_pilot, + size=terminal_size, + ), ) if app_path is not None: diff --git a/src/textual/_styles_cache.py b/src/textual/_styles_cache.py index 1082f4c054..3cb4aebeb9 100644 --- a/src/textual/_styles_cache.py +++ b/src/textual/_styles_cache.py @@ -18,7 +18,7 @@ from .strip import Strip if TYPE_CHECKING: - from typing import TypeAlias + from typing_extensions import TypeAlias from .css.styles import StylesBase from .widget import Widget diff --git a/src/textual/_xterm_parser.py b/src/textual/_xterm_parser.py index 336da9af0a..a68da302c7 100644 --- a/src/textual/_xterm_parser.py +++ b/src/textual/_xterm_parser.py @@ -200,8 +200,8 @@ def reissue_sequence_as_keys(reissue_sequence: str) -> None: if not bracketed_paste: # Was it a pressed key event that we received? key_events = list(sequence_to_key_events(sequence)) - for event in key_events: - on_token(event) + for key_event in key_events: + on_token(key_event) if key_events: break # Or a mouse event? diff --git a/src/textual/actions.py b/src/textual/actions.py index dcb5f16e99..d85a0cacc2 100644 --- a/src/textual/actions.py +++ b/src/textual/actions.py @@ -3,6 +3,11 @@ import ast import re +from typing_extensions import Any, TypeAlias + +ActionParseResult: TypeAlias = "tuple[str, tuple[Any, ...]]" +"""An action is its name and the arbitrary tuple of its parameters.""" + class SkipAction(Exception): """Raise in an action to skip the action (and allow any parent bindings to run).""" @@ -15,7 +20,7 @@ class ActionError(Exception): re_action_params = re.compile(r"([\w\.]+)(\(.*?\))") -def parse(action: str) -> tuple[str, tuple[object, ...]]: +def parse(action: str) -> ActionParseResult: """Parses an action string. Args: diff --git a/src/textual/app.py b/src/textual/app.py index d2ed3e0c40..fa70671b03 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -57,7 +57,7 @@ from ._event_broker import NoHandler, extract_handler_actions from ._path import _make_path_object_relative from ._wait import wait_for_idle -from .actions import SkipAction +from .actions import ActionParseResult, SkipAction from .await_remove import AwaitRemove from .binding import Binding, Bindings from .css.query import NoMatches @@ -645,7 +645,7 @@ def _log( self, group: LogGroup, verbosity: LogVerbosity, - _textual_calling_frame: inspect.FrameInfo, + _textual_calling_frame: inspect.Traceback, *objects: Any, **kwargs, ) -> None: @@ -1605,9 +1605,8 @@ async def invoke_ready_callback() -> None: with redirect_stdout(redirector): # type: ignore await run_process_messages() else: - null_file = _NullFile() - with redirect_stderr(null_file): - with redirect_stdout(null_file): + with redirect_stderr(None): + with redirect_stdout(None): await run_process_messages() finally: @@ -1732,16 +1731,17 @@ def _register( if not widgets: return [] - new_widgets = list(widgets) - + widget_list: Iterable[Widget] if before is not None or after is not None: # There's a before or after, which means there's going to be an # insertion, so make it easier to get the new things in the # correct order. - new_widgets = reversed(new_widgets) + widget_list = reversed(widgets) + else: + widget_list = widgets apply_stylesheet = self.stylesheet.apply - for widget in new_widgets: + for widget in widget_list: if not isinstance(widget, Widget): raise AppError(f"Can't register {widget!r}; expected a Widget instance") if widget not in self._registry: @@ -1798,14 +1798,14 @@ def is_mounted(self, widget: Widget) -> bool: async def _close_all(self) -> None: """Close all message pumps.""" - # Close all screens on the stack - for screen in reversed(self._screen_stack): - if screen._running: - await self._prune_node(screen) + # Close all screens on the stack. + for stack_screen in reversed(self._screen_stack): + if stack_screen._running: + await self._prune_node(stack_screen) self._screen_stack.clear() - # Close pre-defined screens + # Close pre-defined screens. for screen in self.SCREENS.values(): if isinstance(screen, Screen) and screen._running: await self._prune_node(screen) @@ -1971,7 +1971,7 @@ async def on_event(self, event: events.Event) -> None: async def action( self, - action: str | tuple[str, tuple[str, ...]], + action: str | ActionParseResult, default_namespace: object | None = None, ) -> bool: """Perform an action. @@ -2069,7 +2069,7 @@ async def _broker_event( else: event.stop() if isinstance(action, (str, tuple)): - await self.action(action, default_namespace=default_namespace) + await self.action(action, default_namespace=default_namespace) # type: ignore[arg-type] elif callable(action): await action() else: @@ -2339,9 +2339,12 @@ def _end_update(self) -> None: def _init_uvloop() -> None: - """ - Import and install the `uvloop` asyncio policy, if available. + """Import and install the `uvloop` asyncio policy, if available. + This is done only once, even if the function is called multiple times. + + This is provided as a nicety for users that have `uvloop` installed independently + of Textual, as `uvloop` is not listed as a Textual dependency. """ global _uvloop_init_done @@ -2349,10 +2352,10 @@ def _init_uvloop() -> None: return try: - import uvloop + import uvloop # type: ignore[reportMissingImports] except ImportError: pass else: - uvloop.install() + uvloop.install() # type: ignore[reportUnknownMemberType] _uvloop_init_done = True diff --git a/src/textual/cli/previews/easing.py b/src/textual/cli/previews/easing.py index 204cd62463..70dc45758a 100644 --- a/src/textual/cli/previews/easing.py +++ b/src/textual/cli/previews/easing.py @@ -92,6 +92,7 @@ def _animation_complete(): target_position = ( END_POSITION if self.position == START_POSITION else START_POSITION ) + assert event.button.id is not None # Should be set to an easing function str. self.animate( "position", value=target_position, @@ -106,7 +107,7 @@ def watch_position(self, value: int): self.opacity_widget.styles.opacity = 1 - value / END_POSITION def on_input_changed(self, event: Input.Changed): - if event.sender.id == "duration-input": + if event.input.id == "duration-input": new_duration = _try_float(event.value) if new_duration is not None: self.duration = new_duration diff --git a/src/textual/css/_style_properties.py b/src/textual/css/_style_properties.py index f892ac4b9e..60389c3cae 100644 --- a/src/textual/css/_style_properties.py +++ b/src/textual/css/_style_properties.py @@ -528,8 +528,8 @@ def __set__(self, obj: StylesBase, spacing: SpacingDimensions | None): string (e.g. ``"blue on #f0f0f0"``). Raises: - ValueError: When the value is malformed, e.g. a ``tuple`` with a length that is - not 1, 2, or 4. + ValueError: When the value is malformed, + e.g. a ``tuple`` with a length that is not 1, 2, or 4. """ _rich_traceback_omit = True if spacing is None: @@ -543,7 +543,9 @@ def __set__(self, obj: StylesBase, spacing: SpacingDimensions | None): str(error), help_text=spacing_wrong_number_of_values_help_text( property_name=self.name, - num_values_supplied=len(spacing), + num_values_supplied=( + 1 if isinstance(spacing, int) else len(spacing) + ), context="inline", ), ) diff --git a/src/textual/css/parse.py b/src/textual/css/parse.py index d1c5f709b8..92dd6eacff 100644 --- a/src/textual/css/parse.py +++ b/src/textual/css/parse.py @@ -264,7 +264,7 @@ def substitute_references( iter_tokens = iter(tokens) - while tokens: + while True: token = next(iter_tokens, None) if token is None: break @@ -274,8 +274,7 @@ def substitute_references( while True: token = next(iter_tokens, None) - # TODO: Mypy error looks legit - if token.name == "whitespace": + if token is not None and token.name == "whitespace": yield token else: break diff --git a/src/textual/css/scalar_animation.py b/src/textual/css/scalar_animation.py index 935be134ac..cf541d298e 100644 --- a/src/textual/css/scalar_animation.py +++ b/src/textual/css/scalar_animation.py @@ -7,14 +7,14 @@ from .scalar import Scalar, ScalarOffset if TYPE_CHECKING: - from ..dom import DOMNode + from ..widget import Widget from .styles import StylesBase class ScalarAnimation(Animation): def __init__( self, - widget: DOMNode, + widget: Widget, styles: StylesBase, start_time: float, attribute: str, diff --git a/src/textual/css/styles.py b/src/textual/css/styles.py index 8d2ccf426a..53664ed1e6 100644 --- a/src/textual/css/styles.py +++ b/src/textual/css/styles.py @@ -335,6 +335,9 @@ def __textual_animation__( if not isinstance(value, (Scalar, ScalarOffset)): return None + from ..widget import Widget + + assert isinstance(self.node, Widget) return ScalarAnimation( self.node, self, @@ -581,7 +584,9 @@ def partial_rich_style(self) -> Style: @dataclass class Styles(StylesBase): node: DOMNode | None = None - _rules: RulesMap = field(default_factory=dict) + _rules: RulesMap = field( + default_factory=lambda: RulesMap() + ) # mypy won't be happy with `default_factory=RulesMap` _updates: int = 0 important: set[str] = field(default_factory=set) @@ -648,14 +653,14 @@ def reset(self) -> None: self._updates += 1 self._rules.clear() # type: ignore - def merge(self, other: Styles) -> None: + def merge(self, other: StylesBase) -> None: """Merge values from another Styles. Args: other: A Styles object. """ self._updates += 1 - self._rules.update(other._rules) + self._rules.update(other.get_rules()) def merge_rules(self, rules: RulesMap) -> None: self._updates += 1 @@ -1066,7 +1071,7 @@ def __rich_repr__(self) -> rich.repr.Result: def refresh(self, *, layout: bool = False, children: bool = False) -> None: self._inline_styles.refresh(layout=layout, children=children) - def merge(self, other: Styles) -> None: + def merge(self, other: StylesBase) -> None: """Merge values from another Styles. Args: diff --git a/src/textual/css/stylesheet.py b/src/textual/css/stylesheet.py index 1ebc1121e3..1c6c9de5f7 100644 --- a/src/textual/css/stylesheet.py +++ b/src/textual/css/stylesheet.py @@ -78,7 +78,7 @@ def __rich_console__( f"{path.absolute() if path else filename}:{line_no}:{col_no}" ) link_style = Style( - link=f"file://{path.absolute()}", + link=f"file://{path.absolute()}" if path else None, color="red", bold=True, italic=True, diff --git a/src/textual/demo.css b/src/textual/demo.css index 3fb8c7d719..9941d8ba51 100644 --- a/src/textual/demo.css +++ b/src/textual/demo.css @@ -1,4 +1,4 @@ - * { +* { transition: background 500ms in_out_cubic, color 500ms in_out_cubic; } @@ -125,7 +125,7 @@ DarkSwitch Switch { } -Screen > Container { +Screen>Container { height: 100%; overflow: hidden; } @@ -222,7 +222,7 @@ LoginForm { border: wide $background; } -LoginForm Button{ +LoginForm Button { margin: 0 1; width: 100%; } @@ -250,7 +250,7 @@ Window { max-height: 16; } -Window > Static { +Window>Static { width: auto; } diff --git a/src/textual/demo.py b/src/textual/demo.py index c80c662396..4fa0f32d60 100644 --- a/src/textual/demo.py +++ b/src/textual/demo.py @@ -205,7 +205,7 @@ def compose(self) -> ComposeResult: def on_mount(self) -> None: self.watch(self.app, "dark", self.on_dark_change, init=False) - def on_dark_change(self, dark: bool) -> None: + def on_dark_change(self) -> None: self.query_one(Switch).value = self.app.dark def on_switch_changed(self, event: Switch.Changed) -> None: @@ -302,7 +302,7 @@ def on_click(self) -> None: self.remove() -class DemoApp(App): +class DemoApp(App[None]): CSS_PATH = "demo.css" TITLE = "Textual Demo" BINDINGS = [ diff --git a/src/textual/devtools/client.py b/src/textual/devtools/client.py index 0ba5705529..3c8592ae61 100644 --- a/src/textual/devtools/client.py +++ b/src/textual/devtools/client.py @@ -35,7 +35,7 @@ class DevtoolsLog(NamedTuple): """ objects_or_string: tuple[Any, ...] | str - caller: inspect.FrameInfo + caller: inspect.Traceback class DevtoolsConsole(Console): diff --git a/src/textual/devtools/redirect_output.py b/src/textual/devtools/redirect_output.py index b79f935805..0e1f934b2e 100644 --- a/src/textual/devtools/redirect_output.py +++ b/src/textual/devtools/redirect_output.py @@ -39,7 +39,10 @@ def write(self, string: str) -> None: if not self.devtools.is_connected: return - previous_frame = inspect.currentframe().f_back + current_frame = inspect.currentframe() + assert current_frame is not None + previous_frame = current_frame.f_back + assert previous_frame is not None caller = inspect.getframeinfo(previous_frame) self._buffer.append(DevtoolsLog(string, caller=caller)) diff --git a/src/textual/devtools/service.py b/src/textual/devtools/service.py index 5e85d2f895..f3c3f19f77 100644 --- a/src/textual/devtools/service.py +++ b/src/textual/devtools/service.py @@ -5,10 +5,10 @@ import json import pickle from json import JSONDecodeError -from typing import Any, cast +from typing import Any import msgpack -from aiohttp import WSMessage, WSMsgType +from aiohttp import WSMsgType from aiohttp.abc import Request from aiohttp.web_ws import WebSocketResponse from rich.console import Console diff --git a/src/textual/dom.py b/src/textual/dom.py index 829b36796c..d7cfc4a2b0 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -24,7 +24,7 @@ from ._context import NoActiveAppError from ._node_list import NodeList from ._types import CallbackType -from .binding import Bindings, BindingType +from .binding import Binding, Bindings, BindingType from .color import BLACK, WHITE, Color from .css._error_tools import friendly_list from .css.constants import VALID_DISPLAY, VALID_VISIBILITY @@ -39,7 +39,7 @@ if TYPE_CHECKING: from .app import App - from .css.query import DOMQuery + from .css.query import DOMQuery, QueryType from .screen import Screen from .widget import Widget from typing_extensions import TypeAlias @@ -276,7 +276,7 @@ def _merge_bindings(cls) -> Bindings: base.__dict__.get("BINDINGS", []), ) ) - keys = {} + keys: dict[str, Binding] = {} for bindings_ in bindings: keys.update(bindings_.keys) return Bindings(keys.values()) @@ -357,7 +357,7 @@ def screen(self) -> "Screen": # Note that self.screen may not be the same as self.app.screen from .screen import Screen - node = self + node: MessagePump | None = self while node is not None and not isinstance(node, Screen): node = node._parent if not isinstance(node, Screen): @@ -771,19 +771,17 @@ def walk_children( nodes.reverse() return cast("list[DOMNode]", nodes) - ExpectType = TypeVar("ExpectType", bound="Widget") - @overload def query(self, selector: str | None) -> DOMQuery[Widget]: ... @overload - def query(self, selector: type[ExpectType]) -> DOMQuery[ExpectType]: + def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ... def query( - self, selector: str | type[ExpectType] | None = None - ) -> DOMQuery[Widget] | DOMQuery[ExpectType]: + self, selector: str | type[QueryType] | None = None + ) -> DOMQuery[Widget] | DOMQuery[QueryType]: """Get a DOM query matching a selector. Args: @@ -792,33 +790,31 @@ def query( Returns: A query object. """ - from .css.query import DOMQuery + from .css.query import DOMQuery, QueryType + from .widget import Widget - query: str | None if isinstance(selector, str) or selector is None: - query = selector + return DOMQuery[Widget](self, filter=selector) else: - query = selector.__name__ - - return DOMQuery(self, filter=query) + return DOMQuery[QueryType](self, filter=selector.__name__) @overload def query_one(self, selector: str) -> Widget: ... @overload - def query_one(self, selector: type[ExpectType]) -> ExpectType: + def query_one(self, selector: type[QueryType]) -> QueryType: ... @overload - def query_one(self, selector: str, expect_type: type[ExpectType]) -> ExpectType: + def query_one(self, selector: str, expect_type: type[QueryType]) -> QueryType: ... def query_one( self, - selector: str | type[ExpectType], - expect_type: type[ExpectType] | None = None, - ) -> ExpectType | Widget: + selector: str | type[QueryType], + expect_type: type[QueryType] | None = None, + ) -> QueryType | Widget: """Get a single Widget matching the given selector or selector type. Args: diff --git a/src/textual/events.py b/src/textual/events.py index d4c1625b21..28887bec9e 100644 --- a/src/textual/events.py +++ b/src/textual/events.py @@ -5,7 +5,7 @@ import rich.repr from rich.style import Style -from ._types import MessageTarget +from ._types import CallbackType, MessageTarget from .geometry import Offset, Size from .keys import _get_key_aliases from .message import Message @@ -28,11 +28,7 @@ def __rich_repr__(self) -> rich.repr.Result: @rich.repr.auto class Callback(Event, bubble=False, verbose=True): - def __init__( - self, - sender: MessageTarget, - callback: Callable[[], Awaitable[None]], - ) -> None: + def __init__(self, sender: MessageTarget, callback: CallbackType) -> None: self.callback = callback super().__init__(sender) diff --git a/src/textual/keys.py b/src/textual/keys.py index 81e0aca381..1c5fe219d1 100644 --- a/src/textual/keys.py +++ b/src/textual/keys.py @@ -5,7 +5,7 @@ # Adapted from prompt toolkit https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/prompt_toolkit/keys.py -class Keys(str, Enum): +class Keys(str, Enum): # type: ignore[no-redef] """ List of keys for use in key bindings. @@ -13,7 +13,9 @@ class Keys(str, Enum): strings. """ - value: str + @property + def value(self) -> str: + return super().value Escape = "escape" # Also Control-[ ShiftEscape = "shift+escape" diff --git a/src/textual/screen.py b/src/textual/screen.py index db82a80de7..dbe5b3a80e 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -12,7 +12,7 @@ from ._types import CallbackType from .css.match import match from .css.parse import parse_selectors -from .dom import DOMNode +from .css.query import QueryType from .geometry import Offset, Region, Size from .reactive import Reactive from .renderables.blank import Blank @@ -169,7 +169,7 @@ def focus_chain(self) -> list[Widget]: return widgets def _move_focus( - self, direction: int = 0, selector: str | type[DOMNode.ExpectType] = "*" + self, direction: int = 0, selector: str | type[QueryType] = "*" ) -> Widget | None: """Move the focus in the given direction. @@ -230,9 +230,7 @@ def _move_focus( return self.focused - def focus_next( - self, selector: str | type[DOMNode.ExpectType] = "*" - ) -> Widget | None: + def focus_next(self, selector: str | type[QueryType] = "*") -> Widget | None: """Focus the next widget, optionally filtered by a CSS selector. If no widget is currently focused, this will focus the first focusable widget. @@ -249,9 +247,7 @@ def focus_next( """ return self._move_focus(1, selector) - def focus_previous( - self, selector: str | type[DOMNode.ExpectType] = "*" - ) -> Widget | None: + def focus_previous(self, selector: str | type[QueryType] = "*") -> Widget | None: """Focus the previous widget, optionally filtered by a CSS selector. If no widget is currently focused, this will focus the first focusable widget. diff --git a/src/textual/scrollbar.py b/src/textual/scrollbar.py index f00dbccb27..3e9acf913c 100644 --- a/src/textual/scrollbar.py +++ b/src/textual/scrollbar.py @@ -300,13 +300,13 @@ def _on_enter(self, event: events.Enter) -> None: def _on_leave(self, event: events.Leave) -> None: self.mouse_over = False - async def action_scroll_down(self) -> None: - await self.post_message( + def action_scroll_down(self) -> None: + self.post_message_no_wait( ScrollDown(self) if self.vertical else ScrollRight(self) ) - async def action_scroll_up(self) -> None: - await self.post_message(ScrollUp(self) if self.vertical else ScrollLeft(self)) + def action_scroll_up(self) -> None: + self.post_message_no_wait(ScrollUp(self) if self.vertical else ScrollLeft(self)) def action_grab(self) -> None: self.capture_mouse() diff --git a/src/textual/widget.py b/src/textual/widget.py index 4d87351a61..b3b0eb49bc 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -65,6 +65,7 @@ if TYPE_CHECKING: from .app import App, ComposeResult + from .message_pump import MessagePump from .scrollbar import ( ScrollBar, ScrollBarCorner, @@ -443,23 +444,27 @@ def get_widget_by_id( self, id: str, expect_type: type[ExpectType] | None = None ) -> ExpectType | Widget: """Return the first descendant widget with the given ID. + Performs a depth-first search rooted at this widget. Args: - id: The ID to search for in the subtree + id: The ID to search for in the subtree. expect_type: Require the object be of the supplied type, or None for any type. - Defaults to None. Returns: The first descendant encountered with this ID. Raises: - NoMatches: if no children could be found for this ID + NoMatches: if no children could be found for this ID. WrongType: if the wrong type was found. """ - for child in walk_depth_first(self): + # We use Widget as a filter_type so that the inferred type of child is Widget. + for child in walk_depth_first(self, filter_type=Widget): try: - return child.get_child_by_id(id, expect_type=expect_type) + if expect_type is None: + return child.get_child_by_id(id) + else: + return child.get_child_by_id(id, expect_type=expect_type) except NoMatches: pass except WrongType as exc: @@ -729,7 +734,9 @@ def _to_widget(child: int | Widget, called: str) -> Widget: # Ensure the child and target are widgets. child = _to_widget(child, "move") - target = _to_widget(before if after is None else after, "move towards") + target = _to_widget( + cast("int | Widget", before if after is None else after), "move towards" + ) # At this point we should know what we're moving, and it should be a # child; where we're moving it to, which should be within the child @@ -2275,7 +2282,7 @@ def get_pseudo_classes(self) -> Iterable[str]: Names of the pseudo classes. """ - node = self + node: MessagePump | None = self while isinstance(node, Widget): if node.disabled: yield "disabled" @@ -2322,7 +2329,9 @@ def post_render(self, renderable: RenderableType) -> ConsoleRenderable: renderable.justify = text_justify renderable = _Styled( - renderable, self.rich_style, self.link_style if self.auto_links else None + cast(ConsoleRenderable, renderable), + self.rich_style, + self.link_style if self.auto_links else None, ) return renderable @@ -2524,7 +2533,7 @@ def refresh( self.check_idle() def remove(self) -> AwaitRemove: - """Remove the Widget from the DOM (effectively deleting it) + """Remove the Widget from the DOM (effectively deleting it). Returns: An awaitable object that waits for the widget to be removed. @@ -2537,16 +2546,16 @@ def render(self) -> RenderableType: """Get renderable for widget. Returns: - Any renderable + Any renderable. """ - render = "" if self.is_container else self.css_identifier_styled + render: Text | str = "" if self.is_container else self.css_identifier_styled return render def _render(self) -> ConsoleRenderable | RichCast: """Get renderable, promoting str to text as required. Returns: - A renderable + A renderable. """ renderable = self.render() if isinstance(renderable, str): diff --git a/src/textual/widgets/_button.py b/src/textual/widgets/_button.py index 156ba3f194..f8b22a9694 100644 --- a/src/textual/widgets/_button.py +++ b/src/textual/widgets/_button.py @@ -4,7 +4,6 @@ from typing import cast import rich.repr -from rich.console import RenderableType from rich.text import Text, TextType from typing_extensions import Literal @@ -145,7 +144,7 @@ class Button(Static, can_focus=True): ACTIVE_EFFECT_DURATION = 0.3 """When buttons are clicked they get the `-active` class for this duration (in seconds)""" - label: reactive[RenderableType] = reactive[RenderableType]("") + label: reactive[TextType] = reactive[TextType]("") """The text label that appears within the button.""" variant = reactive("default") @@ -209,15 +208,14 @@ def watch_variant(self, old_variant: str, variant: str): self.remove_class(f"-{old_variant}") self.add_class(f"-{variant}") - def validate_label(self, label: RenderableType) -> RenderableType: + def validate_label(self, label: TextType) -> TextType: """Parse markup for self.label""" if isinstance(label, str): return Text.from_markup(label) return label - def render(self) -> RenderableType: - label = self.label.copy() - label = Text.assemble(" ", label, " ") + def render(self) -> TextType: + label = Text.assemble(" ", self.label, " ") label.stylize(self.text_style) return label diff --git a/src/textual/widgets/_footer.py b/src/textual/widgets/_footer.py index 076b9445a8..fe859a72d8 100644 --- a/src/textual/widgets/_footer.py +++ b/src/textual/widgets/_footer.py @@ -66,7 +66,7 @@ async def watch_highlight_key(self, value) -> None: self.refresh() def on_mount(self) -> None: - self.watch(self.screen, "focused", self._focus_changed) + self.watch(self.screen, "focused", self._focus_changed) # type: ignore[arg-type] def _focus_changed(self, focused: Widget | None) -> None: self._key_text = None diff --git a/src/textual/widgets/_header.py b/src/textual/widgets/_header.py index 12e07e4c1b..43de8d4acb 100644 --- a/src/textual/widgets/_header.py +++ b/src/textual/widgets/_header.py @@ -133,5 +133,5 @@ def set_title(title: str) -> None: def set_sub_title(sub_title: str) -> None: self.query_one(HeaderTitle).sub_text = sub_title - self.watch(self.app, "title", set_title) - self.watch(self.app, "sub_title", set_sub_title) + self.watch(self.app, "title", set_title) # type: ignore[arg-type] + self.watch(self.app, "sub_title", set_sub_title) # type: ignore[arg-type] diff --git a/src/textual/widgets/_list_item.py b/src/textual/widgets/_list_item.py index e9222a7aac..522305e009 100644 --- a/src/textual/widgets/_list_item.py +++ b/src/textual/widgets/_list_item.py @@ -31,7 +31,7 @@ class ListItem(Widget, can_focus=False): class _ChildClicked(Message): """For informing with the parent ListView that we were clicked""" - pass + sender: "ListItem" def on_click(self, event: events.Click) -> None: self.post_message_no_wait(self._ChildClicked(self)) diff --git a/src/textual/widgets/_list_view.py b/src/textual/widgets/_list_view.py index af41589dbd..e24ecf31b5 100644 --- a/src/textual/widgets/_list_view.py +++ b/src/textual/widgets/_list_view.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import ClassVar +from typing import ClassVar, Optional from textual.await_remove import AwaitRemove from textual.binding import Binding, BindingType @@ -8,7 +8,7 @@ from textual.geometry import clamp from textual.message import Message from textual.reactive import reactive -from textual.widget import AwaitMount +from textual.widget import AwaitMount, Widget from textual.widgets._list_item import ListItem @@ -35,7 +35,7 @@ class ListView(Vertical, can_focus=True, can_focus_children=False): | down | Move the cursor down. | """ - index = reactive(0, always_update=True) + index = reactive[Optional[int]](0, always_update=True) class Highlighted(Message, bubble=True): """Posted when the highlighted item changes. @@ -96,10 +96,12 @@ def on_mount(self) -> None: @property def highlighted_child(self) -> ListItem | None: """The currently highlighted ListItem, or None if nothing is highlighted.""" - if self.index is None: + if self.index is not None and 0 <= self.index < len(self._nodes): + list_item = self._nodes[self.index] + assert isinstance(list_item, ListItem) + return list_item + else: return None - elif 0 <= self.index < len(self._nodes): - return self._nodes[self.index] def validate_index(self, index: int | None) -> int | None: """Clamp the index to the valid range, or set to None if there's nothing to highlight. @@ -129,9 +131,13 @@ def watch_index(self, old_index: int, new_index: int) -> None: """Updates the highlighting when the index changes.""" if self._is_valid_index(old_index): old_child = self._nodes[old_index] + assert isinstance(old_child, ListItem) old_child.highlighted = False + + new_child: Widget | None if self._is_valid_index(new_index): new_child = self._nodes[new_index] + assert isinstance(new_child, ListItem) new_child.highlighted = True else: new_child = None @@ -168,14 +174,22 @@ def clear(self) -> AwaitRemove: def action_select_cursor(self) -> None: """Select the current item in the list.""" selected_child = self.highlighted_child + if selected_child is None: + return self.post_message_no_wait(self.Selected(self, selected_child)) def action_cursor_down(self) -> None: """Highlight the next item in the list.""" + if self.index is None: + self.index = 0 + return self.index += 1 def action_cursor_up(self) -> None: """Highlight the previous item in the list.""" + if self.index is None: + self.index = 0 + return self.index -= 1 def on_list_item__child_clicked(self, event: ListItem._ChildClicked) -> None: