Skip to content

Commit

Permalink
Layout Inference (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Oct 4, 2024
1 parent 716884f commit f11d311
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 3 deletions.
4 changes: 2 additions & 2 deletions luxonis_ml/nn_archive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .archive_generator import ArchiveGenerator
from .config import Config
from .model import Model
from .utils import is_nn_archive
from .utils import infer_layout, is_nn_archive

__all__ = ["ArchiveGenerator", "Model", "Config", "is_nn_archive"]
__all__ = ["ArchiveGenerator", "Model", "Config", "is_nn_archive", "infer_layout"]
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import List, Optional
from contextlib import suppress
from typing import Any, Dict, List, Optional

from pydantic import Field, model_validator
from typing_extensions import Self

from luxonis_ml.utils import BaseModelExtraForbid

from ...utils import infer_layout
from ..enums import DataType, InputType


Expand Down Expand Up @@ -97,3 +99,11 @@ def validate_layout(self) -> Self:
)

return self

@model_validator(mode="before")
@staticmethod
def infer_layout(data: Dict[str, Any]) -> Dict[str, Any]:
if "shape" in data and "layout" not in data:
with suppress(Exception):
data["layout"] = infer_layout(data["shape"])
return data
10 changes: 10 additions & 0 deletions luxonis_ml/nn_archive/config_building_blocks/base_models/output.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from contextlib import suppress
from typing import List, Optional

from pydantic import Field, model_validator
from typing_extensions import Self

from luxonis_ml.utils import BaseModelExtraForbid

from ...utils import infer_layout
from ..enums import DataType


Expand Down Expand Up @@ -49,3 +51,11 @@ def validate_layout(self) -> Self:
raise ValueError("Layout and shape must have the same length.")

return self

@model_validator(mode="after")
def infer_layout(self) -> Self:
if self.layout is None and self.shape is not None:
with suppress(Exception):
self.layout = infer_layout(self.shape)

return self
38 changes: 38 additions & 0 deletions luxonis_ml/nn_archive/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tarfile
from pathlib import Path
from typing import List

from luxonis_ml.utils.filesystem import PathType

Expand All @@ -25,3 +26,40 @@ def is_nn_archive(path: PathType) -> bool:
return False

return True


def infer_layout(shape: List[int]) -> str:
"""Infers a layout for the given shape.
Tries to guess most common layouts for the given shape pattern.
Otherwise, uses the first free letter of the alphabet for each dimension.
Example::
>>> make_default_layout([1, 3, 256, 256])
>>> "NCHW"
>>> make_default_layout([1, 19, 7, 8])
>>> "NABC"
"""
layout = []
i = 0
if shape[0] == 1:
layout.append("N")
i += 1
if len(shape) - i == 3:
if shape[i] < shape[i + 1] and shape[i] < shape[i + 2]:
return "".join(layout + ["C", "H", "W"])
elif shape[-1] < shape[-2] and shape[-1] < shape[-3]:
return "".join(layout + ["H", "W", "C"])
i = 0
while len(layout) < len(shape):
# Starting with "C" for more sensible defaults
letter = chr(ord("A") + i + 2)
if ord(letter) > ord("Z"):
raise ValueError(
f"Too many dimensions ({len(shape)}) for automatic layout."
)

if letter not in layout:
layout.append(letter)
i += 1
return "".join(layout)
16 changes: 16 additions & 0 deletions tests/test_nn_archive/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from luxonis_ml.nn_archive.utils import infer_layout


def test_infer_layout():
assert infer_layout([1, 3, 256, 256]) == "NCHW"
assert infer_layout([1, 1, 256, 256]) == "NCHW"
assert infer_layout([1, 4, 256, 256]) == "NCHW"
assert infer_layout([1, 19, 7, 8]) == "NCDE"
assert infer_layout([256, 256, 3]) == "HWC"
assert infer_layout([256, 256, 1]) == "HWC"
assert infer_layout([256, 256, 12]) == "HWC"

with pytest.raises(ValueError):
infer_layout(list(range(30)))

0 comments on commit f11d311

Please sign in to comment.