Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sc 40 issue cluster intensity #79

Merged
merged 7 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,391 changes: 711 additions & 680 deletions docs/tutorials/general/FlowSOM_for_pixel_and_cell_clustering.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/harpy/image/_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def rasterize(
output_layer: str,
out_shape: tuple[int, int] | None = None, # output shape in y, x.
chunks: int | None = None,
client: Client = None,
client: Client | None = None,
scale_factors: ScaleFactors_t | None = None,
overwrite: bool = False,
) -> SpatialData:
Expand Down
59 changes: 51 additions & 8 deletions src/harpy/image/pixel_clustering/_clustering.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
from __future__ import annotations

import uuid
from collections.abc import Iterable

import dask.array as da
import numpy as np
import pandas as pd
from anndata import AnnData
from dask.array import Array
from dask.distributed import Client
from numpy.typing import NDArray
from spatialdata import SpatialData
from spatialdata import SpatialData, read_zarr
from spatialdata.models import Image3DModel
from spatialdata.models.models import ScaleFactors_t
from spatialdata.transformations import get_transformation

from harpy.image._image import _get_spatial_element, add_labels_layer
from harpy.utils._keys import _INSTANCE_KEY, _REGION_KEY, _SPATIAL, ClusteringKey
from harpy.utils.pylogger import get_pylogger
from harpy.utils.utils import _get_uint_dtype

log = get_pylogger(__name__)

Expand All @@ -39,6 +43,8 @@ def flowsom(
random_state: int = 100,
chunks: str | int | tuple[int, ...] | None = None,
scale_factors: ScaleFactors_t | None = None,
client: Client | None = None,
persist_intermediate: bool = True,
overwrite: bool = False,
**kwargs, # keyword arguments passed to _flowsom
) -> tuple[SpatialData, fs.FlowSOM, pd.Series]:
Expand Down Expand Up @@ -67,9 +73,18 @@ def flowsom(
random_state
A random state for reproducibility of the clustering and sampling.
chunks
Chunk sizes for processing. If provided as a tuple, it should contain chunk sizes for `c`, `(z)`, `y`, `x`.
Chunk sizes used for flowsom inference step on `img_layer`. If provided as a tuple, it should contain chunk sizes for `c`, `(z)`, `y`, `x`.
scale_factors
Scale factors to apply for multiscale
client
A Dask `Client` instance. If specified, during inference, the trained `fs.FlowSOM` model will be scattered (`client.scatter(...)`).
This reduces the size of the task graph and can improve performance by minimizing data transfer overhead during computation.
If not specified, Dask will use the default scheduler as configured on your system (e.g., single-threaded, multithreaded, or a global client if one is running).
persist_intermediate
If set to `True`, intermediate computations will be persisted in memory. If `img_layer`, or one of the elements in `img_layer`, is large, this could lead to increased RAM usage.
Set to `False` to write to an intermediate zarr store instead, which will reduce RAM usage, but will increase computation time slightly.
We advise setting `persist_intermediate` to `True`, as it will only persist an array of dimension `(2,z,y,x)`, of dtype `numpy.uint8`.
Ignored if `sdata` is not backed by a zarr store.
overwrite
If True, overwrites the `output_layer_cluster` and/or `output_layer_metacluster` if it already exists in `sdata`.
**kwargs
Expand Down Expand Up @@ -132,9 +147,6 @@ def _fix_name(layer: str | Iterable[str]):
_array_dim == arr.ndim
), "Image layers specified via parameter `img_layer` should all have same number of dimensions."

if chunks is not None:
arr = arr.rechunk(chunks)

to_squeeze = False
if arr.ndim == 3:
# add trivial z dimension for 2D case
Expand Down Expand Up @@ -169,24 +181,50 @@ def _fix_name(layer: str | Iterable[str]):
# 3D case, save z,y,x position
adata.obsm[_SPATIAL] = arr_sampled[:, -3:]

_, fsom = _flowsom(adata, n_clusters=n_clusters, seed=random_state, **kwargs)
xdim = kwargs.pop("xdim", 10)
ydim = kwargs.pop("ydim", 10)
dtype = _get_uint_dtype(value=xdim * ydim)
_, fsom = _flowsom(adata, n_clusters=n_clusters, seed=random_state, xdim=xdim, ydim=ydim, **kwargs)

if client is not None:
fsom_future = client.scatter(fsom)
else:
fsom_future = fsom

assert len(img_layer) == len(_arr_list)
# 2) apply fsom on all data
for i, _array in enumerate(_arr_list):
if chunks is not None:
if to_squeeze:
# fix chunks to account for fact that we added trivial z-dimension
if isinstance(chunks, Iterable) and not isinstance(chunks, str):
chunks = (chunks[0], 1, chunks[1], chunks[2])
_array = _array.rechunk(chunks)
output_chunks = ((2,), _array.chunks[1], _array.chunks[2], _array.chunks[3])

# predict flowsom clusters
_labels_flowsom = da.map_blocks(
_predict_flowsom_clusters_chunk,
_array, # can also be chunked in c dimension, drop_axis and new_axis take care of this
dtype=np.uint32,
dtype=dtype,
chunks=output_chunks,
drop_axis=0,
new_axis=0,
fsom=fsom,
fsom=fsom_future,
)

# write to intermediate zarr slot or persist, otherwise dask will run the flowsom inference two times (once for clusters, once for metaclusters),
# once for each time we call add_labels_layer.
if sdata.is_backed() and not persist_intermediate:
se_intermediate = Image3DModel.parse(_labels_flowsom)
_labels_flowsom_name = f"labels_flowsom_{uuid.uuid4()}"
sdata.images[_labels_flowsom_name] = se_intermediate
sdata.write_element(_labels_flowsom_name)
sdata = read_zarr(sdata.path)
_labels_flowsom = _get_spatial_element(sdata, layer=_labels_flowsom_name).data
else:
_labels_flowsom = _labels_flowsom.persist()

_labels_flowsom_clusters, _labels_flowsom_metaclusters = _labels_flowsom

# save the predicted clusters and metaclusters as a labels layer
Expand All @@ -208,6 +246,10 @@ def _fix_name(layer: str | Iterable[str]):
overwrite=overwrite,
)

if sdata.is_backed() and not persist_intermediate:
del sdata[_labels_flowsom_name]
sdata.delete_element_from_disk(element_name=_labels_flowsom_name)

# TODO decide on fix in flowsom to let clusters count from 1.
# fsom cluster IDs count from 0, while labels layer cluster IDs count from 1.

Expand Down Expand Up @@ -294,6 +336,7 @@ def _remove_nan_columns(array):
final_array.shape[0],
size=num_samples,
replace=False,
chunks=num_samples, # indices can not be multi-chunk, see https://github.com/dask/dask/blob/a9396a913c33de1d5966df9cc1901fd70107c99b/dask/array/random.py#L896
)

if remove_nan_columns:
Expand Down
2 changes: 1 addition & 1 deletion src/harpy/image/segmentation/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def add_grid_labels_layer(
grid_type: str = "hexagon", # can be either "hexagon" or "square".
offset: tuple[int, int] = (0, 0), # we recommend setting a non-zero offset via a translation.
chunks: int | None = None,
client: Client = None,
client: Client | None = None,
transformations: MappingToCoordinateSystem_t | None = None,
scale_factors: ScaleFactors_t | None = None,
overwrite: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/harpy/table/pixel_clustering/_cluster_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def cluster_intensity(
channels
Specifies the channels to be included in the intensity calculation.
chunks
Chunk sizes for processing. If provided as a tuple, it should contain chunk sizes for `c`, `(z)`, `y`, `x`.
Chunk sizes for processing. If provided as a `tuple`, it should contain chunk sizes for `c`, `(z)`, `y`, `x`.
overwrite
If True, overwrites the `output_layer` if it already exists in `sdata`.

Expand Down
6 changes: 4 additions & 2 deletions src/harpy/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,10 @@ def _get_uint_dtype(value: int) -> str:
max_uint64 = np.iinfo(np.uint64).max
max_uint32 = np.iinfo(np.uint32).max
max_uint16 = np.iinfo(np.uint16).max

if max_uint16 >= value:
max_uint8 = np.iinfo(np.uint8).max
if max_uint8 >= value:
dtype = "uint8"
elif max_uint16 >= value:
dtype = "uint16"
elif max_uint32 >= value:
dtype = "uint32"
Expand Down
Loading