diff --git a/CHANGELOG.md b/CHANGELOG.md index 546efcc169..18b876a3cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added +- Updated `DataTable.get_cell` type hints to accept string keys https://github.com/Textualize/textual/issues/2586 +- Added `DataTable.get_cell_coordinate` method +- Added `DataTable.get_row_index` method https://github.com/Textualize/textual/issues/2587 +- Added `DataTable.get_column_index` method - Added can-focus pseudo-class to target widgets that may receive focus - Make `Markdown.update` optionally awaitable https://github.com/Textualize/textual/pull/2838 - Added `default` parameter to `DataTable.add_column` for populating existing rows https://github.com/Textualize/textual/pull/2836 diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index e08affcc0a..bfa80b32a2 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -772,7 +772,7 @@ def update_cell_at( row_key, column_key = self.coordinate_to_cell_key(coordinate) self.update_cell(row_key, column_key, value, update_width=update_width) - def get_cell(self, row_key: RowKey, column_key: ColumnKey) -> CellType: + def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellType: """Given a row key and column key, return the value of the corresponding cell. Args: @@ -805,6 +805,32 @@ def get_cell_at(self, coordinate: Coordinate) -> CellType: row_key, column_key = self.coordinate_to_cell_key(coordinate) return self.get_cell(row_key, column_key) + def get_cell_coordinate( + self, row_key: RowKey | str, column_key: Column | str + ) -> Coordinate: + """Given a row key and column key, return the corresponding cell coordinate. + + Args: + row_key: The row key of the cell. + column_key: The column key of the cell. + + Returns: + The current coordinate of the cell identified by the row and column keys. + + Raises: + CellDoesNotExist: If the specified cell does not exist. + """ + if ( + row_key not in self._row_locations + or column_key not in self._column_locations + ): + raise CellDoesNotExist( + f"No cell exists for row_key={row_key!r}, column_key={column_key!r}." + ) + row_index = self._row_locations.get(row_key) + column_index = self._column_locations.get(column_key) + return Coordinate(row_index, column_index) + def get_row(self, row_key: RowKey | str) -> list[CellType]: """Get the values from the row identified by the given row key. @@ -844,6 +870,22 @@ def get_row_at(self, row_index: int) -> list[CellType]: row_key = self._row_locations.get_key(row_index) return self.get_row(row_key) + def get_row_index(self, row_key: RowKey | str) -> int: + """Return the current index for the row identified by row_key. + + Args: + row_key: The row key to find the current index of. + + Returns: + The current index of the specified row key. + + Raises: + RowDoesNotExist: If the row key does not exist. + """ + if row_key not in self._row_locations: + raise RowDoesNotExist(f"No row exists for row_key={row_key!r}") + return self._row_locations.get(row_key) + def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]: """Get the values from the column identified by the given column key. @@ -882,6 +924,22 @@ def get_column_at(self, column_index: int) -> Iterable[CellType]: column_key = self._column_locations.get_key(column_index) yield from self.get_column(column_key) + def get_column_index(self, column_key: ColumnKey | str) -> int: + """Return the current index for the column identified by column_key. + + Args: + column_key: The column key to find the current index of. + + Returns: + The current index of the specified column key. + + Raises: + ColumnDoesNotExist: If the column key does not exist. + """ + if column_key not in self._column_locations: + raise ColumnDoesNotExist(f"No column exists for column_key={column_key!r}") + return self._column_locations.get(column_key) + def _clear_caches(self) -> None: self._row_render_cache.clear() self._cell_render_cache.clear() diff --git a/tests/test_data_table.py b/tests/test_data_table.py index baa338b897..6c971eb87b 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -392,6 +392,46 @@ async def test_get_cell_invalid_column_key(): table.get_cell("R1", "INVALID_COLUMN") +async def test_get_cell_coordinate_returns_coordinate(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_column("Column2", key="C2") + table.add_column("Column3", key="C3") + table.add_row("ValR1C1", "ValR1C2", "ValR1C3", key="R1") + table.add_row("ValR2C1", "ValR2C2", "ValR2C3", key="R2") + table.add_row("ValR3C1", "ValR3C2", "ValR3C3", key="R3") + + assert table.get_cell_coordinate('R1', 'C1') == Coordinate(0, 0) + assert table.get_cell_coordinate('R2', 'C2') == Coordinate(1, 1) + assert table.get_cell_coordinate('R1', 'C3') == Coordinate(0, 2) + assert table.get_cell_coordinate('R3', 'C1') == Coordinate(2, 0) + assert table.get_cell_coordinate('R3', 'C2') == Coordinate(2, 1) + + +async def test_get_cell_coordinate_invalid_row_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + + with pytest.raises(CellDoesNotExist): + coordinate = table.get_cell_coordinate('INVALID_ROW', 'C1') + + +async def test_get_cell_coordinate_invalid_column_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + + with pytest.raises(CellDoesNotExist): + coordinate = table.get_cell_coordinate('R1', 'INVALID_COLUMN') + + async def test_get_cell_at_returns_value_at_cell(): app = DataTableApp() async with app.run_test(): @@ -465,6 +505,32 @@ async def test_get_row_at_invalid_index(index): table.get_row_at(index) +async def test_get_row_index_returns_index(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_column("Column2", key="C2") + table.add_row("ValR1C1", "ValR1C2", key="R1") + table.add_row("ValR2C1", "ValR2C2", key="R2") + table.add_row("ValR3C1", "ValR3C2", key="R3") + + assert table.get_row_index('R1') == 0 + assert table.get_row_index('R2') == 1 + assert table.get_row_index('R3') == 2 + + +async def test_get_row_index_invalid_row_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_row("TargetValue", key="R1") + + with pytest.raises(RowDoesNotExist): + index = table.get_row_index('InvalidRow') + + async def test_get_column(): app = DataTableApp() async with app.run_test(): @@ -509,6 +575,34 @@ async def test_get_column_at_invalid_index(index): with pytest.raises(ColumnDoesNotExist): list(table.get_column_at(index)) +async def test_get_column_index_returns_index(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_column("Column2", key="C2") + table.add_column("Column3", key="C3") + table.add_row("ValR1C1", "ValR1C2", "ValR1C3", key="R1") + table.add_row("ValR2C1", "ValR2C2", "ValR2C3", key="R2") + + assert table.get_column_index('C1') == 0 + assert table.get_column_index('C2') == 1 + assert table.get_column_index('C3') == 2 + + +async def test_get_column_index_invalid_column_key(): + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + table.add_column("Column1", key="C1") + table.add_column("Column2", key="C2") + table.add_column("Column3", key="C3") + table.add_row("TargetValue1", "TargetValue2", "TargetValue3", key="R1") + + with pytest.raises(ColumnDoesNotExist): + index = table.get_column_index('InvalidCol') + + async def test_update_cell_cell_exists(): app = DataTableApp()