Skip to content

Commit

Permalink
Fixed and tested index bug points model (#471)
Browse files Browse the repository at this point in the history
fixed and tested index bug points model
  • Loading branch information
LucaMarconato authored Feb 25, 2024
1 parent 9cf12d3 commit 6f71197
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ unfixable = ["B", "C4", "UP", "BLE", "T20", "RET"]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D", "PT", "B024"]
"*/__init__.py" = ["F401", "D104", "D107", "E402"]
"docs/*" = ["D","B","E","A"]
Expand Down
9 changes: 6 additions & 3 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def parse(cls, data: Any, **kwargs: Any) -> DaskDataFrame:
with key as *valid axes* and value as column names in dataframe.
annotation
Annotation dataframe. Only if `data` is :class:`numpy.ndarray`.
Annotation dataframe. Only if `data` is :class:`numpy.ndarray`. If data is an array, the index of the
annotations will be used as the index of the parsed points.
coordinates
Mapping of axes names (keys) to column names (valus) in `data`. Only if `data` is
:class:`pandas.DataFrame`. Example: {'x': 'my_x_column', 'y': 'my_y_column'}.
Expand Down Expand Up @@ -535,7 +536,8 @@ def _(
assert len(data.shape) == 2
ndim = data.shape[1]
axes = [X, Y, Z][:ndim]
table: DaskDataFrame = dd.from_pandas(pd.DataFrame(data, columns=axes), **kwargs) # type: ignore[attr-defined]
index = annotation.index if annotation is not None else None
table: DaskDataFrame = dd.from_pandas(pd.DataFrame(data, columns=axes, index=index), **kwargs) # type: ignore[attr-defined]
if annotation is not None:
if feature_key is not None:
feature_categ = dd.from_pandas( # type: ignore[attr-defined]
Expand Down Expand Up @@ -579,7 +581,8 @@ def _(
axes = [X, Y, Z][:ndim]
if isinstance(data, pd.DataFrame):
table: DaskDataFrame = dd.from_pandas( # type: ignore[attr-defined]
pd.DataFrame(data[[coordinates[ax] for ax in axes]].to_numpy(), columns=axes), **kwargs
pd.DataFrame(data[[coordinates[ax] for ax in axes]].to_numpy(), columns=axes, index=data.index),
**kwargs,
)
if feature_key is not None:
feature_categ = dd.from_pandas(
Expand Down
12 changes: 8 additions & 4 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,13 @@ def test_points_model(
if coordinates is not None:
coordinates = coordinates.copy()
coords = ["A", "B", "C", "x", "y", "z"]
data = pd.DataFrame(RNG.integers(0, 101, size=(10, 6)), columns=coords)
data["target"] = pd.Series(RNG.integers(0, 2, size=(10,))).astype(str)
data["cell_id"] = pd.Series(RNG.integers(0, 5, size=(10,))).astype(np.int_)
data["anno"] = pd.Series(RNG.integers(0, 1, size=(10,))).astype(np.int_)
n = 10
data = pd.DataFrame(RNG.integers(0, 101, size=(n, 6)), columns=coords)
data["target"] = pd.Series(RNG.integers(0, 2, size=(n,))).astype(str)
data["cell_id"] = pd.Series(RNG.integers(0, 5, size=(n,))).astype(np.int_)
data["anno"] = pd.Series(RNG.integers(0, 1, size=(n,))).astype(np.int_)
# to test for non-contiguous indices
data.drop(index=2, inplace=True)
if not is_3d:
if coordinates is not None:
del coordinates["z"]
Expand Down Expand Up @@ -296,6 +299,7 @@ def test_points_model(
for axis in axes:
assert np.array_equal(points[axis], data[coordinates[axis]])
self._passes_validation_after_io(model, points, "points")
assert np.all(points.index.compute() == data.index)
assert "transform" in points.attrs
if feature_key is not None and is_annotation:
assert "spatialdata_attrs" in points.attrs
Expand Down

0 comments on commit 6f71197

Please sign in to comment.