Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve output shape validation #395

Merged
merged 3 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 54 additions & 8 deletions bioimageio/spec/model/v0_4/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing
from copy import deepcopy

import numpy
from marshmallow import RAISE, ValidationError, missing, pre_load, validates, validates_schema

from bioimageio.spec.model.v0_3.schema import (
Expand Down Expand Up @@ -422,19 +423,64 @@ def no_duplicate_output_tensor_names(self, value: typing.List[raw_nodes.OutputTe
raise ValidationError("Duplicate output tensor names are not allowed.")

@validates_schema
def inputs_and_outputs(self, data, **kwargs):
    """Cross-validate the ``inputs`` and ``outputs`` tensor descriptions.

    Checks performed, in order:

    1. input and output tensor names are unique across both lists;
    2. an output whose shape references another tensor has the same number
       of dimensions as that tensor;
    3. the minimal shape of every output, reduced by twice its halo, is at
       least 1 along every axis
       (see https://github.com/bioimage-io/spec-bioimage-io/issues/392).

    Args:
        data: the loaded model data; only ``"inputs"`` and ``"outputs"`` are read.
        **kwargs: accepted for the schema-validator signature; unused here.

    Raises:
        ValidationError: if ``inputs``/``outputs`` are missing or malformed,
            or if any of the checks above fails.
    """
    ipts: typing.List[raw_nodes.InputTensor] = data.get("inputs")
    outs: typing.List[raw_nodes.OutputTensor] = data.get("outputs")
    # The checks must short-circuit: iterating a missing (None) or non-list
    # value would raise TypeError instead of the intended ValidationError.
    if (
        not isinstance(ipts, list)
        or not isinstance(outs, list)
        or not all(isinstance(v, raw_nodes.InputTensor) for v in ipts)
        or not all(isinstance(v, raw_nodes.OutputTensor) for v in outs)
    ):
        raise ValidationError("Could not check for duplicate tensor names due to another validation error.")

    # no duplicate tensor names
    names = [t.name for t in ipts + outs]  # type: ignore
    if len(names) > len(set(names)):
        raise ValidationError("Duplicate tensor names are not allowed.")

    tensors_by_name: typing.Dict[str, typing.Union[raw_nodes.InputTensor, raw_nodes.OutputTensor]] = {
        t.name: t for t in ipts + outs  # type: ignore
    }

    def resolve_reference(t):
        """Return the tensor referenced by the implicit shape of ``t``."""
        ref_name = t.shape.reference_tensor
        try:
            return tensors_by_name[ref_name]
        except KeyError:
            # report an invalid reference as a validation error, not a crash
            raise ValidationError(f"Tensor {t.name} references unknown tensor {ref_name}.") from None

    # minimum shape leads to valid output:
    # output with subtracted halo has to result in meaningful output even for the minimal input
    # see https://github.com/bioimage-io/spec-bioimage-io/issues/392
    def get_min_shape(t, _seen: typing.Tuple[str, ...] = ()) -> numpy.ndarray:
        """Return the smallest shape tensor ``t`` can take.

        ``_seen`` holds the names already visited while following shape
        references, so a reference cycle is reported instead of recursing
        without end.
        """
        if isinstance(t.shape, raw_nodes.ParametrizedInputShape):
            return numpy.array(t.shape.min)

        if isinstance(t.shape, raw_nodes.ImplicitOutputShape):
            if t.name in _seen:
                raise ValidationError(f"Circular shape reference involving tensor {t.name}.")

            ref_min_shape = get_min_shape(resolve_reference(t), _seen + (t.name,))
            return ref_min_shape * t.shape.scale + 2 * numpy.array(t.shape.offset)

        # explicit shape
        return numpy.array(t.shape)

    for out in outs:
        if isinstance(out.shape, raw_nodes.ImplicitOutputShape):
            ref = resolve_reference(out)
            if len(out.shape) != len(ref.shape):
                raise ValidationError(
                    f"Referenced tensor {out.shape.reference_tensor} "
                    f"with {len(ref.shape)} dimensions does not match "
                    f"output tensor {out.name} with {len(out.shape)} dimensions."
                )

        min_out_shape = get_min_shape(out)
        if out.halo:
            halo = out.halo
            halo_msg = f" for halo {out.halo}"
        else:
            # no halo given: nothing is cropped from the output
            halo = [0] * len(min_out_shape)
            halo_msg = ""

        # the halo is removed on both sides of each axis, hence the factor 2
        if any(s - 2 * h < 1 for s, h in zip(min_out_shape, halo)):
            raise ValidationError(f"Minimal shape {min_out_shape} of output {out.name} is too small{halo_msg}.")

test_inputs = fields.List(
fields.Union([fields.URI(), fields.RelativeLocalPath()]),
validate=field_validators.Length(min=1),
Expand Down
2 changes: 1 addition & 1 deletion example_specs/models/stardist_example_model/rdf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ inputs:
- kwargs: {axes: yx, max_percentile: 99.8, min_percentile: 1.0, mode: per_sample}
name: scale_range
shape:
min: [1, 16, 16, 1]
min: [1, 80, 80, 1]
step: [0, 16, 16, 0]
license: CC-BY-NC-4.0
name: StardistExampleModel
Expand Down
67 changes: 67 additions & 0 deletions tests/test_schema_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,70 @@ def test_model_0_4_raises_on_duplicate_tensor_names(invalid_rdf_v0_4_0_duplicate

valid_data = model_schema.load(data)
assert valid_data


def test_output_fixed_shape_too_small(model_dict):
    """A fixed output shape that is too small for its halo is rejected on load."""
    from bioimageio.spec.model.schema import Model

    # explicit 128x128x3 output whose halo of 128 on the second axis leaves nothing
    output_spec = {
        "name": "output_1",
        "description": "Output 1",
        "data_type": "float32",
        "axes": "xyc",
        "shape": [128, 128, 3],
        "halo": [32, 128, 0],
    }
    model_dict["outputs"] = [output_spec]

    expected_messages = {
        "_schema": ["Minimal shape [128 128 3] of output output_1 is too small for halo [32, 128, 0]."]
    }

    with pytest.raises(ValidationError) as exc_info:
        Model().load(model_dict)

    assert exc_info.value.messages == expected_messages


def test_output_ref_shape_mismatch(model_dict):
    """An implicit output shape with a different rank than its reference is rejected."""
    from bioimageio.spec.model.schema import Model

    # four scale/offset entries, while the expected error reports 3 dimensions for input_1
    implicit_shape = {"reference_tensor": "input_1", "scale": [1, 2, 3, 4], "offset": [0, 0, 0, 0]}
    output_spec = {
        "name": "output_1",
        "description": "Output 1",
        "data_type": "float32",
        "axes": "xyc",
        "shape": implicit_shape,
    }
    model_dict["outputs"] = [output_spec]

    expected_messages = {
        "_schema": [
            "Referenced tensor input_1 with 3 dimensions does not match output tensor output_1 with 4 dimensions."
        ]
    }

    with pytest.raises(ValidationError) as exc_info:
        Model().load(model_dict)

    assert exc_info.value.messages == expected_messages


def test_output_ref_shape_too_small(model_dict):
    """An implicit output shape whose minimal size is too small for its halo is rejected."""
    from bioimageio.spec.model.schema import Model

    # shape derived from input_1; the halo is larger than the derived minimal shape allows
    implicit_shape = {"reference_tensor": "input_1", "scale": [1, 2, 3], "offset": [0, 0, 0]}
    output_spec = {
        "name": "output_1",
        "description": "Output 1",
        "data_type": "float32",
        "axes": "xyc",
        "shape": implicit_shape,
        "halo": [256, 128, 0],
    }
    model_dict["outputs"] = [output_spec]

    expected_messages = {
        "_schema": ["Minimal shape [128. 256. 9.] of output output_1 is too small for halo [256, 128, 0]."]
    }

    with pytest.raises(ValidationError) as exc_info:
        Model().load(model_dict)

    assert exc_info.value.messages == expected_messages