Skip to content

Commit

Permalink
Introduce prepare_label_group helper function, and use it in cellpo…
Browse files Browse the repository at this point in the history
…se task (ref #458)
  • Loading branch information
tcompa committed Aug 30, 2023
1 parent f23236a commit c6dcae3
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 41 deletions.
104 changes: 97 additions & 7 deletions fractal_tasks_core/lib_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,16 @@ def write_table(
image_group:
The group to write to.
table_name:
The key to write to in the group. Note that absolute paths will be
written from the root.
The name of the new table.
table:
The AnnData table to write.
overwrite:
TBD
TBD.
table_attrs:
If set, overwrite all attributes of the new-table zarr group with
the key/value pairs in `table_attrs`.
If set, overwrite table_group attributes with table_attrs key/value
pairs.
logger:
The logger to use (if unset, use `logging.getLogger(None)`)
The logger to use (if unset, use `logging.getLogger(None)`).
Returns:
Zarr group of the new table.
Expand Down Expand Up @@ -273,7 +272,98 @@ def write_table(
"proposed table specs. "
f"Original error: KeyError: {str(e)}"
)
# Overwrite all table_attrs key/value pairs into table_group attributes
# Overwrite table_group attributes with table_attrs key/value pairs
table_group.attrs.put(table_attrs)

return table_group


def prepare_label_group(
image_group: zarr.Group,
label_name: str,
overwrite: bool = False,
label_attrs: Optional[dict[str, Any]] = None,
logger: Optional[logging.Logger] = None,
) -> zarr.group:
"""
Set the stage for writing labels to a zarr group
This helper function is similar to `write_table`, in that it prepares the
appropriate zarr groups (`labels` and the new-label one) and performs
`overwrite`-dependent checks; at a difference with `write_table`, this
function does not actually write the label array to the new zarr group;
such writing operation must take place in the actual task function, since
in fractal-tasks-core it is done sequentially on different `region`s of the
zarr array.
What this function does is:
1. Create the `labels` group, if needed.
2. If `overwrite=False`, check that the new label does not exist (either in
zarr attributes or as a zarr sub-group).
3. Update the `labels` attribute of the image group.
4. If `label_attrs` is set, include this set of attributes in the
new-label zarr group.
Args:
image_group:
The group to write to.
label_name:
The name of the new label.
overwrite:
TBD
label_attrs:
If set, overwrite label_group attributes with label_attrs key/value
pairs.
logger:
The logger to use (if unset, use `logging.getLogger(None)`).
Returns:
Zarr group of the new label.
"""

# Set logger
if logger is None:
logger = logging.getLogger(None)

# Create labels group (if needed) and extract current_labels
if "labels" not in set(image_group.group_keys()):
labels_group = image_group.create_group("labels", overwrite=False)
else:
labels_group = image_group["labels"]
current_labels = labels_group.attrs.asdict().get("labels", [])

# If overwrite=False, check that the new label does not exist (either as a
# zarr sub-group or as part of the zarr-group attributes)
if not overwrite:
if label_name in set(labels_group.group_keys()):
error_msg = (
f"Sub-group '{label_name}' of group {image_group.store.path} "
f"already exists, but `{overwrite=}`. "
"Hint: try setting `overwrite=True`."
)
logger.error(error_msg)
raise OverwriteNotAllowedError(error_msg)
if label_name in current_labels:
error_msg = (
f"Item '{label_name}' already exists in `labels` attribute of "
f"group {image_group.store.path}, but `{overwrite=}`. "
"Hint: try setting `overwrite=True`."
)
logger.error(error_msg)
raise OverwriteNotAllowedError(error_msg)

# Update the `labels` metadata of the image group, if needed
if label_name not in current_labels:
new_labels = current_labels + [label_name]
labels_group.attrs["labels"] = new_labels

# Define new-label group
label_group = labels_group.create_group(label_name)

# Optionally update attributes of the new-table zarr group
if label_attrs is not None:
# Overwrite label_group attributes with label_attrs key/value pairs
label_group.attrs.put(label_attrs)

return label_group
64 changes: 30 additions & 34 deletions fractal_tasks_core/tasks/cellpose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from fractal_tasks_core.lib_regions_of_interest import load_region
from fractal_tasks_core.lib_ROI_overlaps import find_overlaps_in_ROI_indices
from fractal_tasks_core.lib_ROI_overlaps import get_overlapping_pairs_3D
from fractal_tasks_core.lib_zarr import prepare_label_group
from fractal_tasks_core.lib_zattrs_utils import extract_zyx_pixel_sizes
from fractal_tasks_core.lib_zattrs_utils import rescale_datasets

Expand Down Expand Up @@ -423,42 +424,37 @@ def cellpose_segmentation(
remove_channel_axis=True,
)

# Write zattrs for labels and for specific label
new_labels = [output_label_name]
try:
with open(f"{zarrurl}/labels/.zattrs", "r") as f_zattrs:
existing_labels = json.load(f_zattrs)["labels"]
except FileNotFoundError:
existing_labels = []
intersection = set(new_labels) & set(existing_labels)
logger.info(f"{new_labels=}")
logger.info(f"{existing_labels=}")
if intersection:
raise RuntimeError(
f"Labels {intersection} already exist but are also part of outputs"
)
labels_group = zarr.group(f"{zarrurl}/labels")
labels_group.attrs["labels"] = existing_labels + new_labels

label_group = labels_group.create_group(output_label_name)
label_group.attrs["image-label"] = {
"version": __OME_NGFF_VERSION__,
"source": {"image": "../../"},
}
label_group.attrs["multiscales"] = [
{
"name": output_label_name,
label_attrs = {
"image-label": {
"version": __OME_NGFF_VERSION__,
"axes": [
ax for ax in multiscales[0]["axes"] if ax["type"] != "channel"
],
"datasets": new_datasets,
}
]
"source": {"image": "../../"},
},
"multiscales": [
{
"name": output_label_name,
"version": __OME_NGFF_VERSION__,
"axes": [
ax
for ax in multiscales[0]["axes"]
if ax["type"] != "channel"
],
"datasets": new_datasets,
}
],
}

image_group = zarr.group(zarrurl)
label_group = prepare_label_group(
image_group,
output_label_name,
overwrite=overwrite,
label_attrs=label_attrs,
logger=logger,
)

# Open new zarr group for mask 0-th level
zarr.group(f"{zarrurl}/labels")
zarr.group(f"{zarrurl}/labels/{output_label_name}")
logger.info(
f"Helper function `prepare_label_group` returned {label_group=}"
)
logger.info(f"Output label path: {zarrurl}/labels/{output_label_name}/0")
store = zarr.storage.FSStore(f"{zarrurl}/labels/{output_label_name}/0")
label_dtype = np.uint32
Expand Down

0 comments on commit c6dcae3

Please sign in to comment.