Skip to content

Commit

Permalink
Merge pull request #119 from bioimage-io/weight-utils
Browse files Browse the repository at this point in the history
Add utility function for adding additional weight formats to model spec
  • Loading branch information
constantinpape authored Jun 17, 2021
2 parents afff3a4 + 0316007 commit 7c70d8c
Show file tree
Hide file tree
Showing 18 changed files with 212 additions and 197 deletions.
2 changes: 1 addition & 1 deletion bioimageio/spec/latest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from bioimageio.spec.v0_3 import * # noqa
from .build_spec import build_spec
from .build_spec import add_weights, build_spec, serialize_spec
37 changes: 33 additions & 4 deletions bioimageio/spec/latest/build_spec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
import datetime
import hashlib
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np

import bioimageio.spec as spec


#
# utility functions to build the spec from python
#
Expand Down Expand Up @@ -48,7 +51,6 @@ def _infer_weight_type(path):
raise ValueError(f"Could not infer weight type from extension {ext} for weight file {path}")


# TODO extend supported weight types
def _get_weights(weight_uri, weight_type, source, root, **kwargs):
weight_path = _get_local_path(weight_uri, root)
if weight_type is None:
Expand Down Expand Up @@ -213,12 +215,18 @@ def _get_output_tensor(test_out, name, reference_input, scale, offset, axes, dat
return outputs


def _build_authors(authors: List[Dict[str, str]]):
    """Convert author dictionaries into raw ``Author`` nodes.

    Each dictionary is expanded into the keyword arguments of
    ``spec.raw_nodes.Author``; the returned list keeps the input order.
    """
    author_nodes = []
    for author_kwargs in authors:
        author_nodes.append(spec.raw_nodes.Author(**author_kwargs))
    return author_nodes


# TODO The citation entry should be improved so that we can properly derive doi vs. url
def _build_cite(cite):
def _build_cite(cite: Dict[str, str]):
    """Turn a ``{citation text: url}`` mapping into raw ``CiteEntry`` nodes.

    Every key becomes the ``text`` and every value the ``url`` of one
    ``spec.raw_nodes.CiteEntry``; entries follow the mapping's iteration order.
    """
    entries = []
    for text, url in cite.items():
        entries.append(spec.raw_nodes.CiteEntry(text=text, url=url))
    return entries


# TODO we should make the name more specific: "build_model_spec"?
# TODO maybe "build_raw_model" as it returns raw_nodes.Model
# NOTE does not support multiple input / output tensors yet
# to implement this we should wait for 0.4.0, see also
# https://github.com/bioimage-io/spec-bioimage-io/issues/70#issuecomment-825737433
Expand Down Expand Up @@ -384,8 +392,13 @@ def build_spec(
}
kwargs = {k: v for k, v in optional_kwargs.items() if v is not None}

# build the citation object
# build raw_nodes objects
authors = _build_authors(authors)
cite = _build_cite(cite)
documentation = Path(documentation)
covers = [spec.fields.URI().deserialize(uri) for uri in covers]
test_inputs = [spec.fields.URI().deserialize(uri) for uri in test_inputs]
test_outputs = [spec.fields.URI().deserialize(uri) for uri in test_outputs]

model = spec.raw_nodes.Model(
format_version=format_version,
Expand All @@ -412,3 +425,19 @@ def build_spec(
model = spec.schema.Model().load(serialized)

return model


def add_weights(model, weight_uri: str, root: Optional[str] = None, weight_type: Optional[str] = None, **weight_kwargs):
    """Add weight entry to bioimage.io model.

    Args:
        model: the model whose ``weights`` mapping is extended (updated in place).
        weight_uri: location of the additional weight file.
        root: forwarded to ``_get_weights`` as its ``root`` argument.
        weight_type: weight format; forwarded to ``_get_weights`` unchanged.
        **weight_kwargs: extra keyword arguments forwarded to ``_get_weights``.

    Returns:
        The model obtained by dumping the updated ``model`` with
        ``spec.schema.Model`` and loading the result again.
    """
    # No 'source' is passed for additional weights (third positional argument is None);
    # only the first element of the result holds the new weight entries.
    weight_result = _get_weights(weight_uri, weight_type, None, root, **weight_kwargs)
    model.weights.update(weight_result[0])

    # Round-trip through the schema so the returned model reflects the added weights.
    dumped = spec.schema.Model().dump(model)
    return spec.schema.Model().load(dumped)


def serialize_spec(model, out_path):  # TODO change name to include model (see build_model_spec)
    """Dump ``model`` with ``spec.schema.Model`` and write the result to ``out_path``.

    The dumped data is handed to ``spec.utils.yaml.dump`` together with ``out_path``.
    """
    model_data = spec.schema.Model().dump(model)
    spec.utils.yaml.dump(model_data, out_path)
4 changes: 2 additions & 2 deletions bioimageio/spec/shared/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ def __init__(self, **super_kwargs):


class Kwargs(Dict):
def __init__(self, keys=String, missing=dict, bioimageio_description="Key word arguments.", **super_kwargs):
super().__init__(keys, missing=missing, bioimageio_description=bioimageio_description, **super_kwargs)
def __init__(self, keys=String, bioimageio_description="Key word arguments.", **super_kwargs):
super().__init__(keys, bioimageio_description=bioimageio_description, **super_kwargs)


class OutputShape(Union):
Expand Down
16 changes: 13 additions & 3 deletions bioimageio/spec/shared/raw_nodes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
"""shared raw nodes that shared transformer act on"""

import dataclasses
from dataclasses import dataclass
from pathlib import Path
from typing import List
from typing import List, Union

try:
from typing import get_args, get_origin
except ImportError:
from typing_extensions import get_args, get_origin

from marshmallow import missing


@dataclass
class Node:
    """Base dataclass for raw nodes that rejects required fields left at ``missing``.

    A field is treated as optional only if its annotation is a ``Union`` with a
    member class that the marshmallow ``missing`` sentinel is an instance of.
    Any other field that still holds ``missing`` after ``__init__`` is reported
    as a missing required argument.
    """

    def __post_init__(self):
        """Check that no required field still holds the ``missing`` sentinel.

        Raises:
            TypeError: if a field whose annotation does not admit ``missing``
                was left at ``missing``.
        """
        for f in dataclasses.fields(self):
            if getattr(self, f.name) is not missing:
                continue

            # Only plain classes are valid isinstance() targets. Subscripted generics
            # (e.g. List[int]) would make isinstance() itself raise an unrelated
            # TypeError, so they are skipped when looking for a member admitting `missing`.
            admits_missing = get_origin(f.type) is Union and any(
                isinstance(member, type) and get_origin(member) is None and isinstance(missing, member)
                for member in get_args(f.type)
            )
            if not admits_missing:
                raise TypeError(f"{self.__class__}.__init__() missing required argument: '{f.name}'")


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions bioimageio/spec/shared/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import raw_nodes


class SharedPyBioSchema(Schema):
class SharedBioImageIOSchema(Schema):
raw_nodes: ModuleType = raw_nodes # should be overwritten in subclass by version specific raw nodes module
bioimageio_description: str = ""

Expand All @@ -31,7 +31,7 @@ def make_object(self, data, **kwargs):
raise e


class ImplicitInputShape(SharedPyBioSchema):
class ImplicitInputShape(SharedBioImageIOSchema):
min = fields.List(
fields.Integer, required=True, bioimageio_description="The minimum input shape with same length as `axes`"
)
Expand All @@ -50,7 +50,7 @@ def matching_lengths(self, data, **kwargs):
raise ValidationError(f"'min' and 'step' have to have the same length! (min: {min_}, step: {step})")


class ImplicitOutputShape(SharedPyBioSchema):
class ImplicitOutputShape(SharedBioImageIOSchema):
reference_input = fields.String(required=True, bioimageio_description="Name of the reference input tensor.")
scale = fields.List(
fields.Float, required=True, bioimageio_description="'output_pix/input_pix' for each dimension."
Expand Down
4 changes: 4 additions & 0 deletions bioimageio/spec/v0_3/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def maybe_convert_model(data: Dict[str, Any]) -> Dict[str, Any]:
if config.get("future") == {}:
del config["future"]

# remove 'config' if now empty
if data.get("config") == {}:
del data["config"]

return data


Expand Down
72 changes: 37 additions & 35 deletions bioimageio/spec/v0_3/raw_nodes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import distutils.version
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NewType, Optional, Tuple, Union
from typing import Any, Dict, List, NewType, Tuple, Union

from marshmallow import missing
from marshmallow.utils import _Missing

from bioimageio.spec.shared.raw_nodes import (
ImplicitInputShape,
Expand Down Expand Up @@ -51,8 +52,8 @@
@dataclass
class Author(Node):
name: str = missing
affiliation: Optional[str] = missing
orcid: Optional[str] = missing
affiliation: Union[_Missing, str] = missing
orcid: Union[_Missing, str] = missing


ImportableSource = Union[ImportableModule, ImportablePath]
Expand All @@ -61,37 +62,37 @@ class Author(Node):
@dataclass
class CiteEntry(Node):
text: str = missing
doi: Optional[str] = missing
url: Optional[str] = missing
doi: Union[_Missing, str] = missing
url: Union[_Missing, str] = missing


@dataclass
class RunMode(Node):
name: str = missing
kwargs: Dict[str, Any] = missing
kwargs: Union[_Missing, Dict[str, Any]] = missing


@dataclass
class RDF(Node):
attachments: Dict[str, Any] = missing
attachments: Union[_Missing, Dict[str, Any]] = missing
authors: List[Author] = missing
cite: List[CiteEntry] = missing
config: dict = missing
covers: List[URI] = missing
dependencies: Optional[Dependencies] = missing
config: Union[_Missing, dict] = missing
covers: Union[_Missing, List[URI]] = missing
dependencies: Union[_Missing, Dependencies] = missing
description: str = missing
documentation: URI = missing
format_version: FormatVersion = missing
framework: Framework = missing
git_repo: Optional[str] = missing
language: Language = missing
framework: Union[_Missing, Framework] = missing
git_repo: Union[_Missing, str] = missing
language: Union[_Missing, Language] = missing
license: str = missing
name: str = missing
run_mode: Optional[RunMode] = missing
run_mode: Union[_Missing, RunMode] = missing
tags: List[str] = missing
timestamp: datetime = missing
type: Type = missing
version: Optional[distutils.version.StrictVersion] = missing
version: Union[_Missing, distutils.version.StrictVersion] = missing


@dataclass
Expand All @@ -112,9 +113,9 @@ class InputTensor:
data_type: str = missing
axes: Axes = missing
shape: Union[List[int], ImplicitInputShape] = missing
preprocessing: List[Preprocessing] = missing
description: Optional[str] = missing
data_range: Tuple[float, float] = missing
preprocessing: Union[_Missing, List[Preprocessing]] = missing
description: Union[_Missing, str] = missing
data_range: Union[_Missing, Tuple[float, float]] = missing


@dataclass
Expand All @@ -123,24 +124,24 @@ class OutputTensor:
data_type: str = missing
axes: Axes = missing
shape: Union[List[int], ImplicitOutputShape] = missing
halo: List[int] = missing
postprocessing: List[Postprocessing] = missing
description: Optional[str] = missing
data_range: Tuple[float, float] = missing
halo: Union[_Missing, List[int]] = missing
postprocessing: Union[_Missing, List[Postprocessing]] = missing
description: Union[_Missing, str] = missing
data_range: Union[_Missing, Tuple[float, float]] = missing


@dataclass
class WeightsEntry(Node):
authors: List[Author] = missing
attachments: Dict = missing
parent: Optional[str] = missing
authors: Union[_Missing, List[Author]] = missing
attachments: Union[_Missing, Dict] = missing
parent: Union[_Missing, str] = missing
# ONNX specific
opset_version: Optional[int] = missing
opset_version: Union[_Missing, int] = missing
# tag: Optional[str] # todo: check schema. only valid for tensorflow_saved_model_bundle format
# todo: check schema. only valid for tensorflow_saved_model_bundle format
sha256: str = missing
sha256: Union[_Missing, str] = missing
source: URI = missing
tensorflow_version: Optional[distutils.version.StrictVersion] = missing
tensorflow_version: Union[_Missing, distutils.version.StrictVersion] = missing


@dataclass
Expand All @@ -152,16 +153,17 @@ class ModelParent(Node):
@dataclass
class Model(RDF):
inputs: List[InputTensor] = missing
kwargs: Dict[str, Any] = missing
kwargs: Union[_Missing, Dict[str, Any]] = missing
outputs: List[OutputTensor] = missing
packaged_by: List[Author] = missing
parent: ModelParent = missing
sample_inputs: List[URI] = missing
sample_outputs: List[URI] = missing
sha256: str = missing
source: Optional[ImportableSource] = missing
packaged_by: Union[_Missing, List[Author]] = missing
parent: Union[_Missing, ModelParent] = missing
sample_inputs: Union[_Missing, List[URI]] = missing
sample_outputs: Union[_Missing, List[URI]] = missing
sha256: Union[_Missing, str] = missing
source: Union[_Missing, ImportableSource] = missing
test_inputs: List[URI] = missing
test_outputs: List[URI] = missing
type: Type = "model"
weights: Dict[WeightsFormat, WeightsEntry] = missing


Expand Down
Loading

0 comments on commit 7c70d8c

Please sign in to comment.