Fix CI (dask pinning, test with rasterize, mypy) #809

Merged · 5 commits · Dec 16, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
     "anndata>=0.9.1",
     "click",
     "dask-image",
-    "dask>=2024.4.1",
+    "dask>=2024.4.1,<=2024.11.2",
     "fsspec",
     "geopandas>=0.14",
     "multiscale_spatial_image>=2.0.2",
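The new upper bound keeps CI on dask releases that still work with the test suite. A minimal sketch for checking a local environment against the same constraint; the use of the `packaging` library here is an illustrative assumption, not part of the diff:

from importlib.metadata import version
from packaging.specifiers import SpecifierSet

# Same specifier string as the pinned dependency above.
pinned = SpecifierSet(">=2024.4.1,<=2024.11.2")
installed = version("dask")
assert installed in pinned, f"dask {installed} falls outside {pinned}"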
23 changes: 18 additions & 5 deletions src/spatialdata/_core/operations/rasterize.py
@@ -217,9 +217,9 @@ def rasterize(
         The table optionally containing the `value_key` and the name of the table in the returned `SpatialData` object.
         Must be `None` when `data` is a `SpatialData` object, otherwise it assumes the default value of `'table'`.
     return_regions_as_labels
-        By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels and shapes as
-        labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Points and images are always returned
-        as images, and multiscale raster data is always returned as single-scale data.
+        By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels, shapes and points
+        as labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Images are always returned as images,
+        and multiscale raster data is always returned as single-scale data.
     agg_func
         Available only when rasterizing points and shapes. A reduction function from datashader (its name, or a
         `Callable`). See the notes for more details on the default behavior.
@@ -234,6 +234,11 @@
     into a `DataArray` (not a `DataTree`). So if a `SpatialData` object with elements is passed, a `SpatialData` object
     with single-scale images and labels will be returned.
 
+    When `return_regions_as_labels` is `True`, the returned `DataArray` object will have an attribute called
+    `label_index_to_category` that maps the label index to the category name. You can access it via
+    `returned_data.attrs["label_index_to_category"]`. The returned labels will start from 1 (0 is reserved for the
+    background), and will be contiguous.
+
     Notes
     -----
     For images and labels, the parameters `value_key`, `table_name`, `agg_func`, and `return_single_channel` are not
@@ -587,7 +592,7 @@ def rasterize_images_labels(
     )
     assert isinstance(transformed_dask, DaskArray)
     channels = xdata.coords["c"].values if schema in (Image2DModel, Image3DModel) else None
-    transformed_data = schema.parse(transformed_dask, dims=xdata.dims, c_coords=channels)  # type: ignore[call-arg,arg-type]
+    transformed_data = schema.parse(transformed_dask, dims=xdata.dims, c_coords=channels)  # type: ignore[call-arg]
 
     if target_coordinate_system != "global":
         remove_transformation(transformed_data, "global")
@@ -650,7 +655,7 @@ def rasterize_shapes_points(
     if value_key is not None:
         kwargs = {"sdata": sdata, "element_name": element_name} if element_name is not None else {"element": data}
         data[VALUES_COLUMN] = get_values(value_key, table_name=table_name, **kwargs).iloc[:, 0]  # type: ignore[arg-type, union-attr]
-    elif isinstance(data, GeoDataFrame):
+    elif isinstance(data, GeoDataFrame) or isinstance(data, DaskDataFrame) and return_regions_as_labels is True:
         value_key = VALUES_COLUMN
         data[VALUES_COLUMN] = data.index.astype("category")
     else:
@@ -706,6 +711,14 @@
     agg = agg.fillna(0)
 
     if return_regions_as_labels:
+        if label_index_to_category is not None:
+            max_label = next(iter(reversed(label_index_to_category.keys())))
+        else:
+            max_label = int(agg.max().values)
+        max_uint16 = np.iinfo(np.uint16).max
+        if max_label > max_uint16:
+            raise ValueError(f"Maximum label index is {max_label}. Values higher than {max_uint16} are not supported.")
+        agg = agg.astype(np.uint16)
         return Labels2DModel.parse(agg, transformations=transformations)
 
     agg = agg.expand_dims(dim={"c": 1}).transpose("c", "y", "x")
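A hedged sketch of the new labels path from the user side. The `blobs` demo dataset, the element name "blobs_circles", and the coordinate bounds are illustrative assumptions, not taken from this diff:

from spatialdata import rasterize
from spatialdata.datasets import blobs

sdata = blobs()
labels = rasterize(
    sdata["blobs_circles"],
    axes=("x", "y"),
    min_coordinate=[0, 0],
    max_coordinate=[512, 512],
    target_coordinate_system="global",
    target_unit_to_pixels=1.0,
    return_regions_as_labels=True,  # (y, x) labels instead of a (c, y, x) image
)
# Attribute documented above: label index -> category. Labels start at 1,
# 0 is background, and the dtype is now capped at uint16 by the new check.
print(labels.dtype, labels.attrs["label_index_to_category"])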
14 changes: 7 additions & 7 deletions src/spatialdata/_core/operations/transform.py
@@ -52,7 +52,7 @@ def _transform_raster(
     c_shape: tuple[int, ...]
     c_shape = (data.shape[0],) if "c" in axes else ()
     new_spatial_shape = tuple(
-        int(np.max(new_v[:, i]) - np.min(new_v[:, i])) for i in range(len(c_shape), n_spatial_dims + len(c_shape))  # type: ignore[operator]
+        int(np.max(new_v[:, i]) - np.min(new_v[:, i])) for i in range(len(c_shape), n_spatial_dims + len(c_shape))
     )
     output_shape = c_shape + new_spatial_shape
     translation_vector = np.min(new_v[:, :-1], axis=0)
@@ -86,8 +86,8 @@ def _transform_raster(
     # min_y_inverse = np.min(new_v_inverse[:, 1])
 
     if "c" in axes:
-        plt.imshow(da.moveaxis(transformed_dask, 0, 2), origin="lower", alpha=0.5)  # type: ignore[attr-defined]
-        plt.imshow(da.moveaxis(im, 0, 2), origin="lower", alpha=0.5)  # type: ignore[attr-defined]
+        plt.imshow(da.moveaxis(transformed_dask, 0, 2), origin="lower", alpha=0.5)
+        plt.imshow(da.moveaxis(im, 0, 2), origin="lower", alpha=0.5)
     else:
         plt.imshow(transformed_dask, origin="lower", alpha=0.5)
         plt.imshow(im, origin="lower", alpha=0.5)
@@ -322,7 +322,7 @@ def _(
     )
     c_coords = data.indexes["c"].values if "c" in data.indexes else None
     # mypy thinks that schema could be ShapesModel, PointsModel, ...
-    transformed_data = schema.parse(transformed_dask, dims=axes, c_coords=c_coords)  # type: ignore[call-arg,arg-type]
+    transformed_data = schema.parse(transformed_dask, dims=axes, c_coords=c_coords)  # type: ignore[call-arg]
     assert isinstance(transformed_data, DataArray)
     old_transformations = get_transformation(data, get_all=True)
     assert isinstance(old_transformations, dict)
@@ -448,7 +448,7 @@ def _(
     for ax in axes:
         indices = xtransformed["dim"] == ax
         new_ax = xtransformed[:, indices]
-        transformed[ax] = new_ax.data.flatten()  # type: ignore[attr-defined]
+        transformed[ax] = new_ax.data.flatten()
 
     old_transformations = get_transformation(data, get_all=True)
     assert isinstance(old_transformations, dict)
@@ -481,9 +481,9 @@ def _(
     )
     # TODO: nitpick, mypy expects a list of literals and here we have a list of strings.
     # I ignored but we may want to fix this
-    affine = transformation.to_affine(axes, axes)  # type: ignore[arg-type]
+    affine = transformation.to_affine(axes, axes)
     matrix = affine.matrix
-    shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist()
+    shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist()  # type: ignore[operator]
     transformed_geometry = data.geometry.affine_transform(shapely_notation)
     transformed_data = data.copy(deep=True)
     transformed_data.attrs[TRANSFORM_KEY] = {DEFAULT_COORDINATE_SYSTEM: Identity()}
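For reference, the `shapely_notation` built above follows shapely's 2D `affine_transform` convention `[a, b, d, e, xoff, yoff]`. A small self-contained sketch of the same slicing, with illustrative values:

import numpy as np
from shapely.affinity import affine_transform
from shapely.geometry import Point

# 3x3 homogeneous matrix: scale by 2, then translate by (10, 20).
matrix = np.array([[2.0, 0.0, 10.0],
                   [0.0, 2.0, 20.0],
                   [0.0, 0.0, 1.0]])
# Same expression as in the diff: flattened linear part, then the translation.
shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist()
print(affine_transform(Point(1.0, 1.0), shapely_notation))  # POINT (12 22)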
8 changes: 6 additions & 2 deletions src/spatialdata/_core/query/relational_query.py
@@ -214,7 +214,7 @@ def _filter_table_by_elements(
     # some instances do not have a corresponding row in the table
     instances = np.setdiff1d(instances, n0)
     assert np.sum(to_keep) == len(instances)
-    assert sorted(set(instances.tolist())) == sorted(set(table.obs[instance_key].tolist()))
+    assert sorted(set(instances.tolist())) == sorted(set(table.obs[instance_key].tolist()))  # type: ignore[type-var]
     table_df = pd.DataFrame({instance_key: table.obs[instance_key], "position": np.arange(len(instances))})
     merged = pd.merge(table_df, pd.DataFrame(index=instances), left_on=instance_key, right_index=True, how="right")
     matched_positions = merged["position"].to_numpy()
@@ -467,7 +467,11 @@ def _left_join_spatialelement_table(
             )
             continue
 
-        joined_indices = joined_indices.dropna() if joined_indices is not None else None
+        if joined_indices is not None:
+            joined_indices = joined_indices.dropna()
+            # if nan were present, the dtype would have been changed to float
+            if joined_indices.dtype == float:
+                joined_indices = joined_indices.astype(int)
         joined_table = table[joined_indices, :].copy() if joined_indices is not None else None
         _inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
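The `astype(int)` guard added above exists because pandas promotes integer data to float as soon as NaNs appear, and `dropna()` does not undo the promotion. A minimal illustration of that pitfall:

import pandas as pd

s = pd.Series([0, 1, None])  # the NaN forces float64
print(s.dtype)               # float64
s = s.dropna()
print(s.dtype)               # still float64 after dropna()
print(s.astype(int).dtype)   # int64, safe to use for integer row indexing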
4 changes: 2 additions & 2 deletions src/spatialdata/_core/query/spatial_query.py
@@ -700,8 +700,8 @@ def _(
     bounding_box_mask = _bounding_box_mask_points(
         points=points_query_coordinate_system,
         axes=axes,
-        min_coordinate=min_c,
-        max_coordinate=max_c,
+        min_coordinate=min_c,  # type: ignore[arg-type]
+        max_coordinate=max_c,  # type: ignore[arg-type]
     )
     if len(bounding_box_mask) == 1:
         bounding_box_mask = bounding_box_mask[0]
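Conceptually, `_bounding_box_mask_points` computes a per-axis inclusion mask for the points. A hedged sketch of the idea (not the library's internal implementation):

import numpy as np

points = np.array([[0.5, 0.5], [2.0, 3.0], [1.0, 1.5]])  # (n_points, n_axes)
min_coordinate = np.array([0.0, 0.0])
max_coordinate = np.array([1.5, 2.0])
# Keep points that fall inside the box along every axis.
mask = np.all((points >= min_coordinate) & (points <= max_coordinate), axis=1)
print(mask)  # [ True False  True]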
2 changes: 1 addition & 1 deletion src/spatialdata/_io/io_points.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 
 import zarr
-from dask.dataframe import DataFrame as DaskDataFrame  # type: ignore[attr-defined]
+from dask.dataframe import DataFrame as DaskDataFrame
 from dask.dataframe import read_parquet
 from ome_zarr.format import Format
14 changes: 5 additions & 9 deletions src/spatialdata/_types.py
@@ -1,18 +1,14 @@
 from typing import Any
 
 import numpy as np
 from xarray import DataArray, DataTree
 
 __all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"]
 
-try:
-    from numpy.typing import DTypeLike, NDArray
-
-    ArrayLike = NDArray[np.float64]
-    IntArrayLike = NDArray[np.int64]  # or any np.integer
-except (ImportError, TypeError):
-    ArrayLike = np.ndarray  # type: ignore[misc]
-    IntArrayLike = np.ndarray  # type: ignore[misc]
-    DTypeLike = np.dtype  # type: ignore[misc, assignment]
+from numpy.typing import DTypeLike, NDArray
+
+ArrayLike = NDArray[np.floating[Any]]
+IntArrayLike = NDArray[np.integer[Any]]
 
 Raster_T = DataArray | DataTree
 ColorLike = tuple[float, ...] | str
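With `np.floating[Any]`/`np.integer[Any]`, the aliases now accept arrays of any precision rather than only 64-bit. A quick sketch of what type-checks under the new definitions; the `center` helper is illustrative:

from typing import Any

import numpy as np
from numpy.typing import NDArray

ArrayLike = NDArray[np.floating[Any]]    # matches float32, float64, ...
IntArrayLike = NDArray[np.integer[Any]]  # matches int8 ... int64, uint*

def center(points: ArrayLike) -> ArrayLike:
    # Works for any float dtype; mypy no longer insists on float64.
    return points - points.mean(axis=0)

center(np.zeros((4, 2), dtype=np.float32))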
2 changes: 1 addition & 1 deletion src/spatialdata/_utils.py
@@ -80,7 +80,7 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
     others = list(data.dims)
     others.remove(axis)
     # mypy (luca's pycharm config) can't see the isclose method of dask array
-    s = da.isclose(data.sum(dim=others), 0)  # type: ignore[attr-defined]
+    s = da.isclose(data.sum(dim=others), 0)
     # TODO: rewrite this to use dask array; can't get it to work with it
     x = s.compute()
     non_zero = np.where(x == 0)[0]
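A compact sketch of the padding detection that `_compute_paddings` performs, on a synthetic array rather than the library call:

import dask.array as da
import numpy as np

image = np.pad(np.ones((3, 3)), 2)  # 2-pixel zero border on each side
# True where a column sums to ~0 (pure padding), mirroring the code above.
s = da.isclose(da.from_array(image).sum(axis=0), 0).compute()
non_zero = np.where(s == 0)[0]      # columns that carry signal
left, right = non_zero[0], image.shape[1] - 1 - non_zero[-1]
print(left, right)  # 2 2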
2 changes: 1 addition & 1 deletion src/spatialdata/dataloader/datasets.py
@@ -144,7 +144,7 @@ def __init__(
                 **dict(rasterize_kwargs),
             )
             if rasterize
-            else bounding_box_query  # type: ignore[assignment]
+            else bounding_box_query
         )
         self._return = self._get_return(return_annotations, table_name)
         self.transform = transform
2 changes: 1 addition & 1 deletion src/spatialdata/datasets.py
@@ -182,7 +182,7 @@ def _image_blobs(
         masks = []
         for i in range(n_channels):
             mask = self._generate_blobs(length=length, seed=i)
-            mask = (mask - mask.min()) / np.ptp(mask)  # type: ignore[attr-defined]
+            mask = (mask - mask.min()) / np.ptp(mask)
             masks.append(mask)
 
         x = np.stack(masks, axis=0)
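The per-channel normalisation touched above is a plain min-max rescale to [0, 1]; in isolation:

import numpy as np

mask = np.array([2.0, 4.0, 6.0])
normalized = (mask - mask.min()) / np.ptp(mask)  # ptp = "peak to peak" = max - min
print(normalized)  # [0.  0.5 1. ]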
12 changes: 6 additions & 6 deletions src/spatialdata/models/models.py
@@ -179,7 +179,7 @@ def parse(
         else:
             if len(set(dims).symmetric_difference(cls.dims.dims)) > 0:
                 raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
-            _reindex = lambda d: dims.index(d)  # type: ignore[union-attr]
+            _reindex = lambda d: dims.index(d)
     else:
         raise ValueError(f"Unsupported data type: {type(data)}.")
@@ -717,7 +717,7 @@ def _(
             stacklevel=2,
         )
     if isinstance(data, pd.DataFrame):
-        table: DaskDataFrame = dd.from_pandas(  # type: ignore[attr-defined]
+        table: DaskDataFrame = dd.from_pandas(
             pd.DataFrame(data[[coordinates[ax] for ax in axes]].to_numpy(), columns=axes, index=data.index),
             # we need to pass sort=True also when the index is sorted to ensure that the divisions are computed
             sort=sort,
@@ -731,9 +731,9 @@
                 data[feature_key].astype(str).astype("category"),
                 sort=sort,
                 **kwargs,
-            )  # type: ignore[attr-defined]
+            )
             table[feature_key] = feature_categ
-    elif isinstance(data, dd.DataFrame):  # type: ignore[attr-defined]
+    elif isinstance(data, dd.DataFrame):
         table = data[[coordinates[ax] for ax in axes]]
         table.columns = axes
         if feature_key is not None:
@@ -774,7 +774,7 @@ def _add_metadata_and_validate(
         instance_key: str | None = None,
         transformations: MappingToCoordinateSystem_t | None = None,
     ) -> DaskDataFrame:
-        assert isinstance(data, dd.DataFrame)  # type: ignore[attr-defined]
+        assert isinstance(data, dd.DataFrame)
         if feature_key is not None or instance_key is not None:
             data.attrs[ATTRS_KEY] = {}
             if feature_key is not None:
@@ -797,7 +797,7 @@ def _add_metadata_and_validate(
         _parse_transformations(data, transformations)
         cls.validate(data)
         # false positive with the PyCharm mypy plugin
-        return data  # type: ignore[no-any-return]
+        return data
 
 
 class TableModel:
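The `sort=True` comment in the hunk above is about dask computing index divisions (partition boundaries) at construction time. A small sketch of that behaviour, with illustrative values:

import dask.dataframe as dd
import pandas as pd

df = pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0], "y": [4.0, 5.0, 6.0, 7.0]})
ddf = dd.from_pandas(df, npartitions=2, sort=True)
# With sort=True dask knows the index boundaries of each partition.
print(ddf.divisions)  # (0, 2, 3)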
24 changes: 12 additions & 12 deletions src/spatialdata/transformations/ngff/ngff_transformations.py
@@ -303,7 +303,7 @@ def __init__(
         self.affine = self._parse_list_into_array(affine)
 
     @classmethod
-    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, d: Transformation_t) -> Self:
         assert isinstance(d["affine"], list)
         last_row = [[0.0] * (len(d["affine"][0]) - 1) + [1.0]]
         return cls(d["affine"] + last_row)
@@ -340,7 +340,7 @@ def transform_points(self, points: ArrayLike) -> ArrayLike:
         self._validate_transform_points_shapes(len(input_axes), points.shape)
         p = np.vstack([points.T, np.ones(points.shape[0])])
         q = self.affine @ p
-        return q[: len(output_axes), :].T  # type: ignore[no-any-return]
+        return q[: len(output_axes), :].T
 
     def to_affine(self) -> "NgffAffine":
         return NgffAffine(
@@ -411,7 +411,7 @@ def __init__(
 
     # TODO: remove type: ignore[valid-type] when https://github.com/python/mypy/pull/14041 is merged
     @classmethod
-    def _from_dict(cls, _: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, _: Transformation_t) -> Self:
         return cls()
 
     def to_dict(self) -> Transformation_t:
@@ -478,7 +478,7 @@ def __init__(
         self.map_axis = map_axis
 
     @classmethod
-    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, d: Transformation_t) -> Self:
         return cls(d["mapAxis"])
 
     def to_dict(self) -> Transformation_t:
@@ -569,7 +569,7 @@ def __init__(
         self.translation = self._parse_list_into_array(translation)
 
     @classmethod
-    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, d: Transformation_t) -> Self:
         return cls(d["translation"])
 
     def to_dict(self) -> Transformation_t:
@@ -636,7 +636,7 @@ def __init__(
         self.scale = self._parse_list_into_array(scale)
 
     @classmethod
-    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, d: Transformation_t) -> Self:
         return cls(d["scale"])
 
     def to_dict(self) -> Transformation_t:
@@ -705,13 +705,13 @@ def __init__(
         self.rotation = self._parse_list_into_array(rotation)
 
     @classmethod
-    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, d: Transformation_t) -> Self:
         x = d["rotation"]
         n = len(x)
         r = math.sqrt(n)
         assert n == int(r * r)
         m = np.array(x).reshape((int(r), int(r))).tolist()
-        return cls(m)
+        return cls(m)  # type: ignore[arg-type]
 
     def to_dict(self) -> Transformation_t:
         d = {
@@ -802,7 +802,7 @@ def __init__(
         self.output_coordinate_system = cs
 
     @classmethod
-    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, d: Transformation_t) -> Self:
         return cls([NgffBaseTransformation.from_dict(t) for t in d["transformations"]])
 
     def to_dict(self) -> Transformation_t:
@@ -941,7 +941,7 @@ def to_affine(self) -> NgffAffine:
         for t in self.transformations:
             latest_output_cs, input_cs, output_cs = NgffSequence._inferring_cs_pre_action(t, latest_output_cs)
             a = t.to_affine()
-            composed = a.affine @ composed
+            composed = a.affine @ composed  # type: ignore[assignment]
             NgffSequence._inferring_cs_post_action(t, input_cs, output_cs)
         if output_axes != latest_output_cs.axes_names:
             raise ValueError(
@@ -1074,7 +1074,7 @@ def __init__(
         self.transformations = transformations
 
     @classmethod
-    def _from_dict(cls, d: Transformation_t) -> Self:  # type: ignore[valid-type]
+    def _from_dict(cls, d: Transformation_t) -> Self:
         return cls([NgffBaseTransformation.from_dict(t) for t in d["transformations"]])
 
     def to_dict(self) -> Transformation_t:
@@ -1133,7 +1133,7 @@ def transform_points(self, points: ArrayLike) -> ArrayLike:
             input_columns_stacked: ArrayLike = np.stack(input_columns, axis=1)
             output_columns_t = t.transform_points(input_columns_stacked)
             for ax, col in zip(t.output_coordinate_system.axes_names, output_columns_t.T, strict=True):
-                output_columns[ax] = col
+                output_columns[ax] = col  # type: ignore[assignment]
             output: ArrayLike = np.stack([output_columns[ax] for ax in output_axes], axis=1)
             return output
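`NgffAffine.transform_points` (first hunk in this file) uses the standard homogeneous-coordinate trick: append a row of ones, apply one matrix product, drop the last row. A self-contained rendering of those three lines, with illustrative values:

import numpy as np

# 2D affine in homogeneous form: identity linear part + translation (5, 7).
affine = np.array([[1.0, 0.0, 5.0],
                   [0.0, 1.0, 7.0],
                   [0.0, 0.0, 1.0]])
points = np.array([[0.0, 0.0], [1.0, 2.0]])  # (n_points, n_axes)

p = np.vstack([points.T, np.ones(points.shape[0])])  # append a row of ones
q = affine @ p                                       # one matmul for the batch
print(q[:2, :].T)  # [[5. 7.] [6. 9.]] -- the translated points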