Skip to content

Commit

Permalink
Merge pull request #157 from scverse/io/fixes
Browse files Browse the repository at this point in the history
minor fix for points model parser
  • Loading branch information
LucaMarconato authored Feb 27, 2023
2 parents 8724f8a + b30886e commit 4fb3433
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
30 changes: 19 additions & 11 deletions spatialdata/_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def parse(cls, data: Any, **kwargs: Any) -> GeoDataFrame:
In the case of (Multi)`Polygons` shapes, the offsets of the polygons must be provided.
radius
Array with the radii of the `Circles`. It must be provided if the shapes are `Circles`.
index
Index of the shapes, must be of type `str`. If None, it's generated automatically.
transformations
Transformations of the shapes.
kwargs
Expand All @@ -376,6 +378,7 @@ def _(
geometry: Literal[0, 3, 6], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON]
offsets: Optional[tuple[ArrayLike, ...]] = None,
radius: Optional[ArrayLike] = None,
index: Optional[ArrayLike] = None,
transformations: Optional[MappingToCoordinateSystem_t] = None,
) -> GeoDataFrame:
geometry = GeometryType(geometry)
Expand All @@ -385,6 +388,8 @@ def _(
if radius is None:
raise ValueError("If `geometry` is `Circles`, `radius` must be provided.")
geo_df[cls.RADIUS_KEY] = radius
if index is not None:
geo_df.index = index
_parse_transformations(geo_df, transformations)
cls.validate(geo_df)
return geo_df
Expand All @@ -396,6 +401,7 @@ def _(
cls,
data: Union[str, Path],
radius: Optional[ArrayLike] = None,
index: Optional[ArrayLike] = None,
transformations: Optional[Any] = None,
**kwargs: Any,
) -> GeoDataFrame:
Expand All @@ -411,6 +417,8 @@ def _(
if radius is None:
raise ValueError("If `geometry` is `Circles`, `radius` must be provided.")
geo_df[cls.RADIUS_KEY] = radius
if index is not None:
geo_df.index = index
_parse_transformations(geo_df, transformations)
cls.validate(geo_df)
return geo_df
Expand Down Expand Up @@ -457,17 +465,6 @@ def validate(cls, data: DaskDataFrame) -> None:
logger.info(
f"Instance key `{instance_key}` could be of type `pd.Categorical`. Consider casting it."
)
# commented out to address this issue: https://github.com/scverse/spatialdata/issues/140
# for c in data.columns:
# # this is not strictly a validation since we are explicitly importing the categories
# # but it is a convenient way to ensure that the categories are known. It also just changes the state of the
# # series, so it is not a big deal.
# if is_categorical_dtype(data[c]):
# if not data[c].cat.known:
# try:
# data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
# except ValueError:
# logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")

@singledispatchmethod
@classmethod
Expand Down Expand Up @@ -593,6 +590,17 @@ def _add_metadata_and_validate(
assert instance_key in data.columns
data.attrs[cls.ATTRS_KEY][cls.INSTANCE_KEY] = instance_key

for c in data.columns:
# Here we explicitly import the categories, which is a convenient way
# to ensure that the categories are known.
# This only changes the state of the series, so it is not a big deal.
if is_categorical_dtype(data[c]):
if not data[c].cat.known:
try:
data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
except ValueError:
logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")

_parse_transformations(data, transformations)
cls.validate(data)
# false positive with the PyCharm mypy plugin
Expand Down
2 changes: 2 additions & 0 deletions spatialdata/_io/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def write_shapes(
shapes_group.create_dataset(name="coords", data=coords)
for i, o in enumerate(offsets):
shapes_group.create_dataset(name=f"offset{i}", data=o)
# index cannot be string
# https://github.com/zarr-developers/zarr-python/issues/1090
shapes_group.create_dataset(name="Index", data=shapes.index.values)
if geometry.name == "POINT":
shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values)
Expand Down

0 comments on commit 4fb3433

Please sign in to comment.