Skip to content

Commit

Permalink
Merge pull request #585 from bioimage-io/fix_slim_packaging
Browse files Browse the repository at this point in the history
Fix packaging with weights format priority
  • Loading branch information
FynnBe authored Apr 22, 2024
2 parents 573e6e5 + 7b7e73d commit 6847989
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 48 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ Made with [contrib.rocks](https://contrib.rocks).

### bioimageio.spec Python package

#### bioimageio.spec 0.5.2post1

* fix model packaging with weights format priority

#### bioimageio.spec 0.5.2

* new patch version model 0.5.2
Expand Down
2 changes: 1 addition & 1 deletion bioimageio/spec/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.5.2"
"version": "0.5.2post1"
}
2 changes: 1 addition & 1 deletion bioimageio/spec/_internal/common_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class StringNode(collections.UserString, ABC):
_node_class: Type[Node]
_node: Optional[Node] = None

def __init__(self: Self, seq: object) -> None:
def __init__(self, seq: object) -> None:
super().__init__(seq)
type_hints = {
fn: t
Expand Down
17 changes: 13 additions & 4 deletions bioimageio/spec/_internal/packaging_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from contextvars import ContextVar, Token
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Sequence, Union

from .io_basics import AbsoluteFilePath, FileName
from .url import HttpUrl
Expand All @@ -16,16 +16,20 @@ class PackagingContext:

bioimageio_yaml_file_name: FileName

file_sources: Dict[FileName, Union[AbsoluteFilePath, HttpUrl]] = field(
default_factory=dict
)
file_sources: Dict[FileName, Union[AbsoluteFilePath, HttpUrl]]
"""File sources to include in the packaged resource"""

weights_priority_order: Optional[Sequence[str]] = None
"""set to select a single weights entry when packaging model resources"""

def replace(
self,
*,
bioimageio_yaml_file_name: Optional[FileName] = None,
file_sources: Optional[Dict[FileName, Union[AbsoluteFilePath, HttpUrl]]] = None,
weights_priority_order: Union[
Optional[Sequence[str]], Literal["unchanged"]
] = "unchanged",
) -> "PackagingContext":
"""return a modified copy"""
return PackagingContext(
Expand All @@ -37,6 +41,11 @@ def replace(
file_sources=(
dict(self.file_sources) if file_sources is None else file_sources
),
weights_priority_order=(
self.weights_priority_order
if weights_priority_order == "unchanged"
else weights_priority_order
),
)

def __enter__(self):
Expand Down
42 changes: 3 additions & 39 deletions bioimageio/spec/_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from ._internal.validation_context import validation_context_var
from ._internal.warning_levels import ERROR
from ._io import load_description
from .model.v0_4 import ModelDescr as ModelDescr04
from .model.v0_4 import WeightsFormat
from .model.v0_5 import ModelDescr as ModelDescr05


def get_os_friendly_file_name(name: str) -> str:
Expand Down Expand Up @@ -59,50 +57,16 @@ def get_resource_package_content(
)
content: Dict[FileName, Union[HttpUrl, AbsoluteFilePath]] = {}
with PackagingContext(
bioimageio_yaml_file_name=bioimageio_yaml_file_name, file_sources=content
bioimageio_yaml_file_name=bioimageio_yaml_file_name,
file_sources=content,
weights_priority_order=weights_priority_order,
):
rdf_content: BioimageioYamlContent = rd.model_dump(
mode="json", exclude_unset=True
)

_ = rdf_content.pop("rdf_source", None)

if weights_priority_order is not None and isinstance(
rd, (ModelDescr04, ModelDescr05)
):
# select single weights entry
assert isinstance(rdf_content["weights"], dict), type(rdf_content["weights"])
for wf in weights_priority_order:
w = rdf_content["weights"].get(wf)
if w is not None:
break
else:
raise ValueError(
"None of the weight formats in `weights_priority_order` is present in"
+ " the given model."
)

assert isinstance(w, dict), type(w)
_ = w.pop("parent", None)
rdf_content["weights"] = {wf: w}
parent = rdf_content.pop("id", None)
parent_version = rdf_content.pop("version", None)
if parent is not None:
rdf_content["parent"] = {"id": parent, "version": parent_version}

with validation_context_var.get().replace(
root=rd.root, file_name=bioimageio_yaml_file_name
):
rd_slim = build_description(rdf_content)

assert not isinstance(
rd_slim, InvalidDescr
), rd_slim.validation_summary.format()
# repackage without other weights entries
return get_resource_package_content(
rd_slim, bioimageio_yaml_file_name=bioimageio_yaml_file_name
)

return {**content, bioimageio_yaml_file_name: rdf_content}


Expand Down
41 changes: 39 additions & 2 deletions bioimageio/spec/model/v0_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
AllowInfNan,
Discriminator,
Field,
SerializationInfo,
SerializerFunctionWrapHandler,
TypeAdapter,
ValidationInfo,
WrapSerializer,
field_validator,
model_validator,
)
Expand All @@ -46,6 +49,7 @@
from .._internal.io import FileDescr as FileDescr
from .._internal.io import Sha256 as Sha256
from .._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
from .._internal.packaging_context import packaging_context_var
from .._internal.types import Datetime as Datetime
from .._internal.types import Identifier as Identifier
from .._internal.types import ImportantFileSource, LowerCaseIdentifier
Expand Down Expand Up @@ -870,6 +874,37 @@ class LinkedModel(Node):
"""version number (n-th published version, not the semantic version) of linked model"""


def package_weights(
    value: Node,
    handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
):
    """Wrap-serializer for a model's `weights` field that honors the packaging context.

    If the current packaging context defines a `weights_priority_order`, only the
    first weights format from that order that is present on `value` is serialized;
    its `parent` link is cleared and all other weights fields are serialized as
    `None`. Without a packaging context (or without a priority order) `value` is
    passed to `handler` untouched.

    The selection is applied to shallow copies, so the description object that is
    being serialized keeps all of its weights entries and parent links.

    Args:
        value: the weights description to serialize (a pydantic model).
        handler: the default serializer to delegate to.
        info: serialization info forwarded to `handler`.

    Returns:
        Whatever `handler` returns for the (possibly reduced) weights description.

    Raises:
        ValueError: if none of the formats in `weights_priority_order` is present.
    """
    ctxt = packaging_context_var.get()
    if ctxt is not None and ctxt.weights_priority_order is not None:
        for wf in ctxt.weights_priority_order:
            w = getattr(value, wf, None)
            if w is not None:
                break
        else:
            raise ValueError(
                "None of the weight formats in `weights_priority_order`"
                + f" ({ctxt.weights_priority_order}) is present in the given model."
            )

        # Serializing must not have side effects on the live description:
        # copy the selected entry without its parent link (the parent entry is
        # dropped from the package) instead of assigning `parent = None` in place.
        selected = w.model_copy(update={"parent": None})

        # Copy the weights description, keeping only the selected format.
        value = value.model_copy(
            update={
                field_name: (selected if field_name == wf else None)
                for field_name in value.model_fields
            }
        )

    return handler(
        value, info  # pyright: ignore[reportArgumentType]  # taken from pydantic docs
    )


class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"):
"""Specification of the fields used in a bioimage.io-compliant RDF that describes AI models with pretrained weights.
Expand All @@ -888,7 +923,9 @@ class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification")
id: Optional[ModelId] = None
"""Model zoo (bioimage.io) wide, unique identifier (assigned by bioimage.io)"""

authors: NotEmpty[List[Author]]
authors: NotEmpty[ # pyright: ignore[reportGeneralTypeIssues] # make mandatory
List[Author]
]
"""The authors are the creators of the model RDF and the primary points of contact."""

documentation: Annotated[
Expand Down Expand Up @@ -1114,7 +1151,7 @@ def ignore_url_parent(cls, parent: Any):
training_data: Union[LinkedDataset, DatasetDescr, None] = None
"""The dataset used to train this model"""

weights: WeightsDescr
weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
"""The weights for this model.
Weights can be given for different formats, but should otherwise be equivalent.
The available weight formats determine which consumers can use this model."""
Expand Down
4 changes: 3 additions & 1 deletion bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
RootModel,
Tag,
ValidationInfo,
WrapSerializer,
field_validator,
model_validator,
)
Expand Down Expand Up @@ -123,6 +124,7 @@
from .v0_4 import TensorName as _TensorName_v0_4
from .v0_4 import WeightsFormat as WeightsFormat
from .v0_4 import ZeroMeanUnitVarianceDescr as _ZeroMeanUnitVarianceDescr_v0_4
from .v0_4 import package_weights

# unit names from https://ngff.openmicroscopy.org/latest/#axes-md
SpaceUnit = Literal[
Expand Down Expand Up @@ -2342,7 +2344,7 @@ def _validate_output_axes(
training_data: Union[None, LinkedDataset, DatasetDescr] = None
"""The dataset used to train this model"""

weights: WeightsDescr
weights: Annotated[WeightsDescr, WrapSerializer(package_weights)]
"""The weights for this model.
Weights can be given for different formats, but should otherwise be equivalent.
The available weight formats determine which consumers can use this model."""
Expand Down

0 comments on commit 6847989

Please sign in to comment.