Skip to content

Commit

Permalink
Merge pull request #564 from bioimage-io/data_dep_size
Browse files Browse the repository at this point in the history
Add DataDependentSize
  • Loading branch information
FynnBe authored Mar 18, 2024
2 parents 7581247 + 404d58c commit d8ce5d0
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 26 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ Made with [contrib.rocks](https://contrib.rocks).
#### model 0.5.1 (planned)

* Non-breaking changes
* added `DataDependentSize` for `outputs.i.size` to specify an output shape that is not known before inference is run.
* added optional `inputs.i.optional` field to indicate that a tensor may be `None`

#### generic 0.3.0 / application 0.3.0 / collection 0.3.0 / dataset 0.3.0 / notebook 0.3.0
Expand Down
22 changes: 11 additions & 11 deletions bioimageio/spec/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,24 @@
SI_UNIT_REGEX = f"^{_unit_ap}((·{_unit_ap})|(/{_unit_pp}))*$"


class MinMax(NamedTuple):
class _DtypeLimit(NamedTuple):
min: Union[int, float]
max: Union[int, float]


# numpy.dtype limits; see scripts/generate_dtype_limits.py
# wrapped in MappingProxyType to expose a read-only view of the table
DTYPE_LIMITS = MappingProxyType(
    {
        "float32": _DtypeLimit(-3.4028235e38, 3.4028235e38),
        "float64": _DtypeLimit(-1.7976931348623157e308, 1.7976931348623157e308),
        "uint8": _DtypeLimit(0, 255),
        "int8": _DtypeLimit(-128, 127),
        "uint16": _DtypeLimit(0, 65535),
        "int16": _DtypeLimit(-32768, 32767),
        "uint32": _DtypeLimit(0, 4294967295),
        "int32": _DtypeLimit(-2147483648, 2147483647),
        "uint64": _DtypeLimit(0, 18446744073709551615),
        "int64": _DtypeLimit(-9223372036854775808, 9223372036854775807),
    }
)

Expand Down
110 changes: 95 additions & 15 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
List,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -247,6 +248,27 @@ def get_size(self, n: ParameterizedSize.N) -> int:
ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)


class DataDependentSize(Node):
    """An axis size that is only known after inference has run
    (e.g. number of detected objects), optionally bounded by `min`/`max`."""

    min: Annotated[int, Gt(0)] = 1
    max: Annotated[Optional[int], Gt(1)] = None
    # `max is None` means the size is unbounded from above

    @model_validator(mode="after")
    def _validate_max_gt_min(self):
        # only reject an explicit upper bound that is not strictly greater
        # than `min`; the default `max=None` (unbounded) is valid
        if self.max is not None and self.min >= self.max:
            raise ValueError(f"expected `min` < `max`, but got {self.min}, {self.max}")

        return self

    def validate_size(self, size: int) -> int:
        """Return `size` if it lies within [`min`, `max`], else raise `ValueError`."""
        if size < self.min:
            raise ValueError(f"size {size} < {self.min}")

        if self.max is not None and size > self.max:
            raise ValueError(f"size {size} > {self.max}")

        return size


class SizeReference(Node):
"""A tensor axis size (extent in pixels/frames) defined in relation to a reference axis.
Expand Down Expand Up @@ -326,6 +348,10 @@ def get_size(
ref_size = ref_axis.size
elif isinstance(ref_axis.size, ParameterizedSize):
ref_size = ref_axis.size.get_size(n)
elif isinstance(ref_axis.size, DataDependentSize):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be a `DataDependentSize`."
)
elif isinstance(ref_axis.size, SizeReference):
raise ValueError(
"Reference axis referenced in `SizeReference` may not be sized by a"
Expand Down Expand Up @@ -451,16 +477,30 @@ class _WithOutputAxisSize(Node):
"""The size/length of this axis can be specified as
- fixed integer
- reference to another axis with an optional offset (`SizeReference`)
# TODO: add `DataDependentSize(min, max, step)`
"""


class IndexInputAxis(IndexAxisBase, _WithInputAxisSize):
    # index axis of an input tensor; its `size` field is inherited from `_WithInputAxisSize`
    pass


class IndexOutputAxis(IndexAxisBase):
    # unlike `IndexInputAxis`, the output variant accepts `DataDependentSize`
    size: Annotated[
        Union[Annotated[int, Gt(0)], SizeReference, DataDependentSize],
        Field(
            examples=[
                10,
                SizeReference(
                    tensor_id=TensorId("t"), axis_id=AxisId("a"), offset=5
                ).model_dump(mode="json"),
            ]
        ),
    ]
    """The size/length of this axis can be specified as
    - fixed integer
    - reference to another axis with an optional offset (`SizeReference`)
    - data dependent size using `DataDependentSize` (size is only known after model inference)
    """


class TimeAxisBase(AxisBase):
Expand Down Expand Up @@ -960,7 +1000,7 @@ def _validate_sample_tensor(self) -> Self:
elif isinstance(a.size, int):
if a.size == 1:
n_dims_min -= 1
elif isinstance(a.size, ParameterizedSize):
elif isinstance(a.size, (ParameterizedSize, DataDependentSize)):
if a.size.min == 1:
n_dims_min -= 1
elif isinstance(a.size, SizeReference):
Expand Down Expand Up @@ -1446,6 +1486,8 @@ def e_msg(d: TensorDescr):
)
elif isinstance(a.size, ParameterizedSize):
_ = a.size.validate_size(actual_size)
elif isinstance(a.size, DataDependentSize):
_ = a.size.validate_size(actual_size)
elif isinstance(a.size, SizeReference):
ref_tensor_axes = all_tensor_axes.get(a.size.tensor_id)
if ref_tensor_axes is None:
Expand Down Expand Up @@ -1728,6 +1770,18 @@ class LinkedModel(Node):
"""version number (n-th published version, not the semantic version) of linked model"""


class _DataDepSize(NamedTuple):
min: int
max: Optional[int]


class _TensorSizes(NamedTuple):
    """Resolved axis sizes, split into sizes known before inference
    (`predetermined`) and bounds of data dependent sizes (`data_dependent`)."""

    predetermined: Dict[Tuple[TensorId, AxisId], int]
    """size of axis (given `n` for `ParameterizedSize`)"""
    data_dependent: Dict[Tuple[TensorId, AxisId], _DataDepSize]
    """min,max size of data dependent axis"""

class ModelDescr(GenericModelDescrBase, title="bioimage.io model specification"):
"""Specification of the fields used in a bioimage.io-compliant RDF to describe AI models with pretrained weights.
These fields are typically stored in a YAML file which we call a model resource description file (model RDF).
Expand Down Expand Up @@ -1826,7 +1880,7 @@ def _validate_axis(
if isinstance(axis, BatchAxis) or isinstance(axis.size, int):
return

if isinstance(axis.size, ParameterizedSize):
if isinstance(axis.size, (ParameterizedSize, DataDependentSize)):
if isinstance(axis, WithHalo) and (axis.size.min - 2 * axis.halo) < 1:
raise ValueError(
f"axis {axis.id} with minimum size {axis.size.min} is too small for"
Expand Down Expand Up @@ -2093,34 +2147,60 @@ def get_output_test_arrays(self) -> List[NDArray[Any]]:
return data

def get_tensor_sizes(
    self, ns: Dict[Tuple[TensorId, AxisId], ParameterizedSize.N], batch_size: int
) -> _TensorSizes:
    """Resolve the sizes of all input and output tensor axes.

    Args:
        ns: size increment factor `n` per parameterized axis, keyed by
            (tensor id, axis id). An entry is required for every
            `ParameterizedSize` axis and for the referenced axis of every
            `SizeReference` axis; surplus entries for batch, fixed size, or
            data dependent axes raise a `ValueError`.
        batch_size: size to use for all batch axes.

    Returns:
        `_TensorSizes` holding sizes known before inference
        (`predetermined`) and min/max bounds of data dependent axes
        (`data_dependent`).
    """
    all_axes = {
        t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
    }

    predetermined: Dict[Tuple[TensorId, AxisId], int] = {}
    data_dependent: Dict[Tuple[TensorId, AxisId], _DataDepSize] = {}
    for t_descr in chain(self.inputs, self.outputs):
        for a in t_descr.axes:
            if isinstance(a, BatchAxis):
                if (t_descr.id, a.id) in ns:
                    raise ValueError(
                        f"No size increment factor (n) for batch axis of tensor {t_descr.id} expected."
                    )
                s = batch_size
            elif isinstance(a.size, int):
                if (t_descr.id, a.id) in ns:
                    raise ValueError(
                        f"No size increment factor (n) for fixed size axis {a.id} of tensor {t_descr.id} expected."
                    )
                s = a.size
            elif isinstance(a.size, ParameterizedSize):
                if (t_descr.id, a.id) not in ns:
                    raise ValueError(
                        f"Size increment factor (n) not given for parametrized axis {a.id} of tensor {t_descr.id}."
                    )
                s = a.size.get_size(ns[(t_descr.id, a.id)])
            elif isinstance(a.size, SizeReference):
                assert not isinstance(a, BatchAxis)
                ref_axis = all_axes[a.size.tensor_id][a.size.axis_id]
                assert not isinstance(ref_axis, BatchAxis)
                if (a.size.tensor_id, a.size.axis_id) not in ns:
                    raise ValueError(
                        f"No increment (n) provided for axis {a.id} of tensor {t_descr.id}. Expected reference from tensor {a.size.tensor_id} and axis {a.size.axis_id}."
                    )
                # the increment is keyed by the *referenced* axis (the axis
                # that is actually parameterized), matching the check above
                s = a.size.get_size(
                    axis=a, ref_axis=ref_axis, n=ns[(a.size.tensor_id, a.size.axis_id)]
                )
            elif isinstance(a.size, DataDependentSize):
                if (t_descr.id, a.id) in ns:
                    raise ValueError(
                        f"No size increment factor (n) for data dependent size axis {a.id} of tensor {t_descr.id} expected."
                    )
                # size is only known after inference; record its bounds instead
                data_dependent[t_descr.id, a.id] = _DataDepSize(
                    a.size.min, a.size.max
                )
                continue
            else:
                assert_never(a.size)

            predetermined[t_descr.id, a.id] = s

    return _TensorSizes(predetermined, data_dependent)

@model_validator(mode="before")
@classmethod
Expand Down
27 changes: 27 additions & 0 deletions tests/test_model/test_v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,33 @@ def test_output_fixed_shape_too_small(model_data: Dict[str, Any]):
assert summary.status == "failed", summary.format()


def test_get_tensor_sizes_raises_with_surplus_n(model_data: Dict[str, Any]):
    """`get_tensor_sizes` must reject an `ns` entry for a non-parameterized axis."""
    with ValidationContext(perform_io_checks=False):
        model = ModelDescr(**model_data)

    # note: these locals read from the first *input* tensor (the previous
    # `output_*` names were misleading); axis "y" of the fixture's input is
    # presumably not parameterized, so supplying `n` for it is a surplus entry
    input_tensor_id = model.inputs[0].id
    input_axis_id = AxisId("y")

    with pytest.raises(ValueError):
        _ = model.get_tensor_sizes(
            ns={(input_tensor_id, input_axis_id): 1}, batch_size=1
        )


def test_get_tensor_sizes_raises_with_missing_n(model_data: Dict[str, Any]):
    """An empty `ns` mapping must raise for an axis sized by a `SizeReference`."""
    size_reference_axis = {
        "type": "space",
        "id": "x",
        "size": {"tensor_id": "input_1", "axis_id": "x"},
        "halo": 0,
    }
    model_data["outputs"][0]["axes"][2] = size_reference_axis

    with ValidationContext(perform_io_checks=False):
        model = ModelDescr(**model_data)
        with pytest.raises(ValueError):
            _ = model.get_tensor_sizes(ns={}, batch_size=1)


def test_output_ref_shape_mismatch(model_data: Dict[str, Any]):
model_data["outputs"][0]["axes"][2] = {
"type": "space",
Expand Down

0 comments on commit d8ce5d0

Please sign in to comment.