Skip to content

Commit

Permalink
Fix CI (dask pinning, test with rasterize, mypy) (#809)
Browse files Browse the repository at this point in the history
passing tests and pre-commits
  • Loading branch information
LucaMarconato authored Dec 16, 2024
1 parent 02bc276 commit 803a66e
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"anndata>=0.9.1",
"click",
"dask-image",
"dask>=2024.4.1",
"dask>=2024.4.1,<=2024.11.2",
"fsspec",
"geopandas>=0.14",
"multiscale_spatial_image>=2.0.2",
Expand Down
23 changes: 18 additions & 5 deletions src/spatialdata/_core/operations/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ def rasterize(
The table optionally containing the `value_key` and the name of the table in the returned `SpatialData` object.
Must be `None` when `data` is a `SpatialData` object, otherwise it assumes the default value of `'table'`.
return_regions_as_labels
By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels and shapes as
labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Points and images are always returned
as images, and multiscale raster data is always returned as single-scale data.
By default, single-scale images of shape `(c, y, x)` are returned. If `True`, returns labels, shapes and points
as labels of shape `(y, x)` as opposed to an image of shape `(c, y, x)`. Images are always returned as images,
and multiscale raster data is always returned as single-scale data.
agg_func
Available only when rasterizing points and shapes. A reduction function from datashader (its name, or a
`Callable`). See the notes for more details on the default behavior.
Expand All @@ -234,6 +234,11 @@ def rasterize(
into a `DataArray` (not a `DataTree`). So if a `SpatialData` object with elements is passed, a `SpatialData` object
with single-scale images and labels will be returned.
When `return_regions_as_labels` is `True`, the returned `DataArray` object will have an attribute called
`label_index_to_category` that maps the label index to the category name. You can access it via
`returned_data.attrs["label_index_to_category"]`. The returned labels will start from 1 (0 is reserved for the
background), and will be contiguous.
Notes
-----
For images and labels, the parameters `value_key`, `table_name`, `agg_func`, and `return_single_channel` are not
Expand Down Expand Up @@ -587,7 +592,7 @@ def rasterize_images_labels(
)
assert isinstance(transformed_dask, DaskArray)
channels = xdata.coords["c"].values if schema in (Image2DModel, Image3DModel) else None
transformed_data = schema.parse(transformed_dask, dims=xdata.dims, c_coords=channels) # type: ignore[call-arg,arg-type]
transformed_data = schema.parse(transformed_dask, dims=xdata.dims, c_coords=channels) # type: ignore[call-arg]

if target_coordinate_system != "global":
remove_transformation(transformed_data, "global")
Expand Down Expand Up @@ -650,7 +655,7 @@ def rasterize_shapes_points(
if value_key is not None:
kwargs = {"sdata": sdata, "element_name": element_name} if element_name is not None else {"element": data}
data[VALUES_COLUMN] = get_values(value_key, table_name=table_name, **kwargs).iloc[:, 0] # type: ignore[arg-type, union-attr]
elif isinstance(data, GeoDataFrame):
elif isinstance(data, GeoDataFrame) or isinstance(data, DaskDataFrame) and return_regions_as_labels is True:
value_key = VALUES_COLUMN
data[VALUES_COLUMN] = data.index.astype("category")
else:
Expand Down Expand Up @@ -706,6 +711,14 @@ def rasterize_shapes_points(
agg = agg.fillna(0)

if return_regions_as_labels:
if label_index_to_category is not None:
max_label = next(iter(reversed(label_index_to_category.keys())))
else:
max_label = int(agg.max().values)
max_uint16 = np.iinfo(np.uint16).max
if max_label > max_uint16:
raise ValueError(f"Maximum label index is {max_label}. Values higher than {max_uint16} are not supported.")
agg = agg.astype(np.uint16)
return Labels2DModel.parse(agg, transformations=transformations)

agg = agg.expand_dims(dim={"c": 1}).transpose("c", "y", "x")
Expand Down
14 changes: 7 additions & 7 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _transform_raster(
c_shape: tuple[int, ...]
c_shape = (data.shape[0],) if "c" in axes else ()
new_spatial_shape = tuple(
int(np.max(new_v[:, i]) - np.min(new_v[:, i])) for i in range(len(c_shape), n_spatial_dims + len(c_shape)) # type: ignore[operator]
int(np.max(new_v[:, i]) - np.min(new_v[:, i])) for i in range(len(c_shape), n_spatial_dims + len(c_shape))
)
output_shape = c_shape + new_spatial_shape
translation_vector = np.min(new_v[:, :-1], axis=0)
Expand Down Expand Up @@ -86,8 +86,8 @@ def _transform_raster(
# min_y_inverse = np.min(new_v_inverse[:, 1])

if "c" in axes:
plt.imshow(da.moveaxis(transformed_dask, 0, 2), origin="lower", alpha=0.5) # type: ignore[attr-defined]
plt.imshow(da.moveaxis(im, 0, 2), origin="lower", alpha=0.5) # type: ignore[attr-defined]
plt.imshow(da.moveaxis(transformed_dask, 0, 2), origin="lower", alpha=0.5)
plt.imshow(da.moveaxis(im, 0, 2), origin="lower", alpha=0.5)
else:
plt.imshow(transformed_dask, origin="lower", alpha=0.5)
plt.imshow(im, origin="lower", alpha=0.5)
Expand Down Expand Up @@ -322,7 +322,7 @@ def _(
)
c_coords = data.indexes["c"].values if "c" in data.indexes else None
# mypy thinks that schema could be ShapesModel, PointsModel, ...
transformed_data = schema.parse(transformed_dask, dims=axes, c_coords=c_coords) # type: ignore[call-arg,arg-type]
transformed_data = schema.parse(transformed_dask, dims=axes, c_coords=c_coords) # type: ignore[call-arg]
assert isinstance(transformed_data, DataArray)
old_transformations = get_transformation(data, get_all=True)
assert isinstance(old_transformations, dict)
Expand Down Expand Up @@ -448,7 +448,7 @@ def _(
for ax in axes:
indices = xtransformed["dim"] == ax
new_ax = xtransformed[:, indices]
transformed[ax] = new_ax.data.flatten() # type: ignore[attr-defined]
transformed[ax] = new_ax.data.flatten()

old_transformations = get_transformation(data, get_all=True)
assert isinstance(old_transformations, dict)
Expand Down Expand Up @@ -481,9 +481,9 @@ def _(
)
# TODO: nitpick, mypy expects a list of literals and here we have a list of strings.
# I ignored it, but we may want to fix this
affine = transformation.to_affine(axes, axes) # type: ignore[arg-type]
affine = transformation.to_affine(axes, axes)
matrix = affine.matrix
shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist()
shapely_notation = matrix[:-1, :-1].ravel().tolist() + matrix[:-1, -1].tolist() # type: ignore[operator]
transformed_geometry = data.geometry.affine_transform(shapely_notation)
transformed_data = data.copy(deep=True)
transformed_data.attrs[TRANSFORM_KEY] = {DEFAULT_COORDINATE_SYSTEM: Identity()}
Expand Down
8 changes: 6 additions & 2 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _filter_table_by_elements(
# some instances do not have a corresponding row in the table
instances = np.setdiff1d(instances, n0)
assert np.sum(to_keep) == len(instances)
assert sorted(set(instances.tolist())) == sorted(set(table.obs[instance_key].tolist()))
assert sorted(set(instances.tolist())) == sorted(set(table.obs[instance_key].tolist())) # type: ignore[type-var]
table_df = pd.DataFrame({instance_key: table.obs[instance_key], "position": np.arange(len(instances))})
merged = pd.merge(table_df, pd.DataFrame(index=instances), left_on=instance_key, right_index=True, how="right")
matched_positions = merged["position"].to_numpy()
Expand Down Expand Up @@ -467,7 +467,11 @@ def _left_join_spatialelement_table(
)
continue

joined_indices = joined_indices.dropna() if joined_indices is not None else None
if joined_indices is not None:
joined_indices = joined_indices.dropna()
# if NaN values were present, the dtype would have been changed to float
if joined_indices.dtype == float:
joined_indices = joined_indices.astype(int)
joined_table = table[joined_indices, :].copy() if joined_indices is not None else None
_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)

Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,8 @@ def _(
bounding_box_mask = _bounding_box_mask_points(
points=points_query_coordinate_system,
axes=axes,
min_coordinate=min_c,
max_coordinate=max_c,
min_coordinate=min_c, # type: ignore[arg-type]
max_coordinate=max_c, # type: ignore[arg-type]
)
if len(bounding_box_mask) == 1:
bounding_box_mask = bounding_box_mask[0]
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_io/io_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path

import zarr
from dask.dataframe import DataFrame as DaskDataFrame # type: ignore[attr-defined]
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import read_parquet
from ome_zarr.format import Format

Expand Down
14 changes: 5 additions & 9 deletions src/spatialdata/_types.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from typing import Any

import numpy as np
from xarray import DataArray, DataTree

__all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"]

try:
from numpy.typing import DTypeLike, NDArray

ArrayLike = NDArray[np.float64]
IntArrayLike = NDArray[np.int64] # or any np.integer
from numpy.typing import DTypeLike, NDArray

except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
IntArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc, assignment]
ArrayLike = NDArray[np.floating[Any]]
IntArrayLike = NDArray[np.integer[Any]]

Raster_T = DataArray | DataTree
ColorLike = tuple[float, ...] | str
2 changes: 1 addition & 1 deletion src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
others = list(data.dims)
others.remove(axis)
# mypy (luca's pycharm config) can't see the isclose method of dask array
s = da.isclose(data.sum(dim=others), 0) # type: ignore[attr-defined]
s = da.isclose(data.sum(dim=others), 0)
# TODO: rewrite this to use dask array; can't get it to work with it
x = s.compute()
non_zero = np.where(x == 0)[0]
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(
**dict(rasterize_kwargs),
)
if rasterize
else bounding_box_query # type: ignore[assignment]
else bounding_box_query
)
self._return = self._get_return(return_annotations, table_name)
self.transform = transform
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _image_blobs(
masks = []
for i in range(n_channels):
mask = self._generate_blobs(length=length, seed=i)
mask = (mask - mask.min()) / np.ptp(mask) # type: ignore[attr-defined]
mask = (mask - mask.min()) / np.ptp(mask)
masks.append(mask)

x = np.stack(masks, axis=0)
Expand Down
12 changes: 6 additions & 6 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def parse(
else:
if len(set(dims).symmetric_difference(cls.dims.dims)) > 0:
raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
_reindex = lambda d: dims.index(d) # type: ignore[union-attr]
_reindex = lambda d: dims.index(d)
else:
raise ValueError(f"Unsupported data type: {type(data)}.")

Expand Down Expand Up @@ -717,7 +717,7 @@ def _(
stacklevel=2,
)
if isinstance(data, pd.DataFrame):
table: DaskDataFrame = dd.from_pandas( # type: ignore[attr-defined]
table: DaskDataFrame = dd.from_pandas(
pd.DataFrame(data[[coordinates[ax] for ax in axes]].to_numpy(), columns=axes, index=data.index),
# we need to pass sort=True also when the index is sorted to ensure that the divisions are computed
sort=sort,
Expand All @@ -731,9 +731,9 @@ def _(
data[feature_key].astype(str).astype("category"),
sort=sort,
**kwargs,
) # type: ignore[attr-defined]
)
table[feature_key] = feature_categ
elif isinstance(data, dd.DataFrame): # type: ignore[attr-defined]
elif isinstance(data, dd.DataFrame):
table = data[[coordinates[ax] for ax in axes]]
table.columns = axes
if feature_key is not None:
Expand Down Expand Up @@ -774,7 +774,7 @@ def _add_metadata_and_validate(
instance_key: str | None = None,
transformations: MappingToCoordinateSystem_t | None = None,
) -> DaskDataFrame:
assert isinstance(data, dd.DataFrame) # type: ignore[attr-defined]
assert isinstance(data, dd.DataFrame)
if feature_key is not None or instance_key is not None:
data.attrs[ATTRS_KEY] = {}
if feature_key is not None:
Expand All @@ -797,7 +797,7 @@ def _add_metadata_and_validate(
_parse_transformations(data, transformations)
cls.validate(data)
# false positive with the PyCharm mypy plugin
return data # type: ignore[no-any-return]
return data


class TableModel:
Expand Down
24 changes: 12 additions & 12 deletions src/spatialdata/transformations/ngff/ngff_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def __init__(
self.affine = self._parse_list_into_array(affine)

@classmethod
def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, d: Transformation_t) -> Self:
assert isinstance(d["affine"], list)
last_row = [[0.0] * (len(d["affine"][0]) - 1) + [1.0]]
return cls(d["affine"] + last_row)
Expand Down Expand Up @@ -340,7 +340,7 @@ def transform_points(self, points: ArrayLike) -> ArrayLike:
self._validate_transform_points_shapes(len(input_axes), points.shape)
p = np.vstack([points.T, np.ones(points.shape[0])])
q = self.affine @ p
return q[: len(output_axes), :].T # type: ignore[no-any-return]
return q[: len(output_axes), :].T

def to_affine(self) -> "NgffAffine":
return NgffAffine(
Expand Down Expand Up @@ -411,7 +411,7 @@ def __init__(

# TODO: remove type: ignore[valid-type] when https://github.com/python/mypy/pull/14041 is merged
@classmethod
def _from_dict(cls, _: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, _: Transformation_t) -> Self:
return cls()

def to_dict(self) -> Transformation_t:
Expand Down Expand Up @@ -478,7 +478,7 @@ def __init__(
self.map_axis = map_axis

@classmethod
def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, d: Transformation_t) -> Self:
return cls(d["mapAxis"])

def to_dict(self) -> Transformation_t:
Expand Down Expand Up @@ -569,7 +569,7 @@ def __init__(
self.translation = self._parse_list_into_array(translation)

@classmethod
def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, d: Transformation_t) -> Self:
return cls(d["translation"])

def to_dict(self) -> Transformation_t:
Expand Down Expand Up @@ -636,7 +636,7 @@ def __init__(
self.scale = self._parse_list_into_array(scale)

@classmethod
def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, d: Transformation_t) -> Self:
return cls(d["scale"])

def to_dict(self) -> Transformation_t:
Expand Down Expand Up @@ -705,13 +705,13 @@ def __init__(
self.rotation = self._parse_list_into_array(rotation)

@classmethod
def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, d: Transformation_t) -> Self:
x = d["rotation"]
n = len(x)
r = math.sqrt(n)
assert n == int(r * r)
m = np.array(x).reshape((int(r), int(r))).tolist()
return cls(m)
return cls(m) # type: ignore[arg-type]

def to_dict(self) -> Transformation_t:
d = {
Expand Down Expand Up @@ -802,7 +802,7 @@ def __init__(
self.output_coordinate_system = cs

@classmethod
def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, d: Transformation_t) -> Self:
return cls([NgffBaseTransformation.from_dict(t) for t in d["transformations"]])

def to_dict(self) -> Transformation_t:
Expand Down Expand Up @@ -941,7 +941,7 @@ def to_affine(self) -> NgffAffine:
for t in self.transformations:
latest_output_cs, input_cs, output_cs = NgffSequence._inferring_cs_pre_action(t, latest_output_cs)
a = t.to_affine()
composed = a.affine @ composed
composed = a.affine @ composed # type: ignore[assignment]
NgffSequence._inferring_cs_post_action(t, input_cs, output_cs)
if output_axes != latest_output_cs.axes_names:
raise ValueError(
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def __init__(
self.transformations = transformations

@classmethod
def _from_dict(cls, d: Transformation_t) -> Self: # type: ignore[valid-type]
def _from_dict(cls, d: Transformation_t) -> Self:
return cls([NgffBaseTransformation.from_dict(t) for t in d["transformations"]])

def to_dict(self) -> Transformation_t:
Expand Down Expand Up @@ -1133,7 +1133,7 @@ def transform_points(self, points: ArrayLike) -> ArrayLike:
input_columns_stacked: ArrayLike = np.stack(input_columns, axis=1)
output_columns_t = t.transform_points(input_columns_stacked)
for ax, col in zip(t.output_coordinate_system.axes_names, output_columns_t.T, strict=True):
output_columns[ax] = col
output_columns[ax] = col # type: ignore[assignment]
output: ArrayLike = np.stack([output_columns[ax] for ax in output_axes], axis=1)
return output

Expand Down
Loading

0 comments on commit 803a66e

Please sign in to comment.