Skip to content

Commit

Permalink
Rasterize shapes (#566)
Browse files Browse the repository at this point in the history
* minimal rasterize structure for shapes

* test rasterize

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename squares -> box

* fix mypy and use python>=3.10 typings

* fix docstrings and support points

* save label_index_to_category mapping dict

* remove column drop, since using view of the data

* fix docs

* reworking docstring

* add tests points

* wip

* remove instance_key_as_default_value_key

* fix tests pre release install

* refactoring

* tests for alternative calls of rasterize(); refactor relational_query

* fix type checking

* using to polygons

* labels now accept table annotations

* testing rasterize for points and shapes when index is str

* testing rasterize() for SpatialData objects

* changelog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Marconato <[email protected]>
  • Loading branch information
3 people authored Jun 9, 2024
1 parent 62a6440 commit 592561f
Show file tree
Hide file tree
Showing 8 changed files with 770 additions and 232 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning][].
### Added

- Added operation: `to_polygons()` @quentinblampey #560
- Extended `rasterize()` to support all the data types @quentinblampey #566

### Minor

Expand Down
20 changes: 20 additions & 0 deletions src/spatialdata/_core/operations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from multiscale_spatial_image import MultiscaleSpatialImage
from spatial_image import SpatialImage

from spatialdata.models import SpatialElement

if TYPE_CHECKING:
from spatialdata._core.spatialdata import SpatialData

Expand Down Expand Up @@ -134,3 +136,21 @@ def transform_to_data_extent(
for k, v in sdata.tables.items():
sdata_to_return_elements[k] = v.copy()
return SpatialData.from_elements_dict(sdata_to_return_elements)


def _parse_element(
element: str | SpatialElement, sdata: SpatialData | None, element_var_name: str, sdata_var_name: str
) -> SpatialElement:
if not ((sdata is not None and isinstance(element, str)) ^ (not isinstance(element, str))):
raise ValueError(
f"To specify the {element_var_name!r} SpatialElement, please do one of the following: "
f"- either pass a SpatialElement to the {element_var_name!r} parameter (and keep "
f"`{sdata_var_name}` = None);"
f"- either `{sdata_var_name}` needs to be a SpatialData object, and {element_var_name!r} needs "
f"to be the string name of the element."
)
if sdata is not None:
assert isinstance(element, str)
return sdata[element]
assert element is not None
return element
31 changes: 9 additions & 22 deletions src/spatialdata/_core/operations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from spatial_image import SpatialImage
from xrspatial import zonal_stats

from spatialdata._core.operations._utils import _parse_element
from spatialdata._core.operations.transform import transform
from spatialdata._core.query.relational_query import get_values
from spatialdata._core.spatialdata import SpatialData
Expand All @@ -26,7 +27,6 @@
Labels2DModel,
PointsModel,
ShapesModel,
SpatialElement,
TableModel,
get_model,
)
Expand All @@ -35,22 +35,6 @@
__all__ = ["aggregate"]


def _parse_element(element: str | SpatialElement, sdata: SpatialData | None, str_for_exception: str) -> SpatialElement:
if not ((sdata is not None and isinstance(element, str)) ^ (not isinstance(element, str))):
raise ValueError(
f"To specify the {str_for_exception!r} SpatialElement, please do one of the following: "
f"- either pass a SpatialElement to the {str_for_exception!r} parameter (and keep "
f"`{str_for_exception}_sdata` = None);"
f"- either `{str_for_exception}_sdata` needs to be a SpatialData object, and {str_for_exception!r} needs "
f"to be the string name of the element."
)
if sdata is not None:
assert isinstance(element, str)
return sdata[element]
assert element is not None
return element


def aggregate(
values: ddf.DataFrame | gpd.GeoDataFrame | SpatialImage | MultiscaleSpatialImage | str,
by: gpd.GeoDataFrame | SpatialImage | MultiscaleSpatialImage | str,
Expand Down Expand Up @@ -90,8 +74,9 @@ def aggregate(
The key can be:
- the name of a column(s) in the dataframe (Dask `DataFrame` for points or `GeoDataFrame` for shapes);
- the name of obs column(s) in the associated `AnnData` table (for shapes and labels);
- the name of a var(s), referring to the column(s) of the X matrix in the table (for shapes and labels).
- the name of obs column(s) in the associated `AnnData` table (for points, shapes and labels);
- the name of a var(s), referring to the column(s) of the X matrix in the table (for points, shapes and
labels).
If nothing is passed here, it defaults to the equivalent of a column of ones.
Defaults to `FEATURE_KEY` for points (if present).
Expand Down Expand Up @@ -127,7 +112,7 @@ def aggregate(
Whether to deepcopy the shapes in the returned `SpatialData` object. If the shapes are large (e.g. large
multiscale labels), you may consider disabling the deepcopy to use a lazy Dask representation.
table_name
The table optionally containing the value_key and the name of the table in the returned `SpatialData` object.
The table optionally containing the `value_key` and the name of the table in the returned `SpatialData` object.
buffer_resolution
Resolution parameter to pass to the of the .buffer() method to convert circles to polygons. A higher value
results in a more accurate representation of the circle, but also in a more complex polygon and computation.
Expand All @@ -154,8 +139,10 @@ def aggregate(
to a large memory usage. This Github issue https://github.com/scverse/spatialdata/issues/210 keeps track of the
changes required to address this behavior.
"""
values_ = _parse_element(element=values, sdata=values_sdata, str_for_exception="values")
by_ = _parse_element(element=by, sdata=by_sdata, str_for_exception="by")
values_ = _parse_element(
element=values, sdata=values_sdata, element_var_name="values", sdata_var_name="values_sdata"
)
by_ = _parse_element(element=by, sdata=by_sdata, element_var_name="by", sdata_var_name="by_sdata")

if values_ is by_:
# this case breaks the groupy aggregation in _aggregate_shapes(), probably a non relevant edge case so
Expand Down
Loading

0 comments on commit 592561f

Please sign in to comment.