Skip to content

Commit

Permalink
Persist color channels in rasterize() (#544)
Browse files Browse the repository at this point in the history
* Add tests for `rasterize` channel persistence

* In `rasterize`: add channels back to `transformed_data`

Otherwise named channels will get lost during rasterize, which
interferes with referencing them by name when plotting.

* added release note

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* replacing try except with schema check

---------

Co-authored-by: Luca Marconato <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 25, 2024
1 parent db6a55b commit 2a38051
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning][].
### Minor

- Removed `pygeos` dependency @omsai #545
- Channel coordinate annotations on images now persist through `rasterize()` @clwgg #544

## [0.1.2] - 2024-03-30

Expand Down
4 changes: 3 additions & 1 deletion src/spatialdata/_core/operations/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ def _(
**kwargs,
)
assert isinstance(transformed_dask, DaskArray)
transformed_data = schema.parse(transformed_dask, dims=xdata.dims) # type: ignore[call-arg,arg-type]
channels = xdata.coords["c"].values if schema in (Image2DModel, Image3DModel) else None
transformed_data = schema.parse(transformed_dask, dims=xdata.dims, c_coords=channels) # type: ignore[call-arg,arg-type]

if target_coordinate_system != "global":
remove_transformation(transformed_data, "global")

Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def blobs_annotating_element(name: BlobsTypes) -> SpatialData:
instance_id = _get_unique_label_values_as_index(sdata[name]).tolist()
else:
index = sdata[name].index
instance_id = index.compute().tolist() if isinstance(index, dask.dataframe.core.Index) else index.tolist()
instance_id = index.compute().tolist() if isinstance(index, dask.dataframe.core.Index) else index.tolist()
n = len(instance_id)
new_table = AnnData(shape=(n, 0), obs={"region": [name for _ in range(n)], "instance_id": instance_id})
new_table = TableModel.parse(new_table, region=name, region_key="region", instance_key="instance_id")
Expand Down
3 changes: 3 additions & 0 deletions tests/core/operations/test_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def _get_data_of_largest_scale(raster):
**kwargs,
)

if "c" in raster.coords:
assert np.array_equal(raster.coords["c"].values, result.coords["c"].values)

result_data = _get_data_of_largest_scale(result)
n_equal = result_data[tuple(slices)] == 1
ratio = np.sum(n_equal) / np.prod(n_equal.shape)
Expand Down

0 comments on commit 2a38051

Please sign in to comment.