
Commit d6bcbc6

implement get_model for torch tensor (#296)
fix dtype inspection for torch
mike0sv authored Jun 16, 2022
1 parent c701297 commit d6bcbc6
Showing 2 changed files with 22 additions and 4 deletions.
1 change: 0 additions & 1 deletion mlem/contrib/numpy.py
@@ -132,7 +132,6 @@ def _subtype(self, subshape: Tuple[Optional[int], ...]):
         )
 
     def get_model(self, prefix: str = "") -> Type[BaseModel]:
-        # TODO: https://github.com/iterative/mlem/issues/33
         return create_model(
             prefix + "NumpyNdarray", __root__=(List[self._subtype(self.shape[1:])], ...)  # type: ignore
         )
25 changes: 22 additions & 3 deletions mlem/contrib/torch.py
@@ -1,8 +1,10 @@
-from typing import Any, ClassVar, Iterator, Optional, Tuple
+from typing import Any, ClassVar, Iterator, List, Optional, Tuple
 
 import torch
+from pydantic import conlist, create_model
 
 from mlem.constants import PREDICT_METHOD_NAME
+from mlem.contrib.numpy import python_type_from_np_string_repr
 from mlem.core.artifacts import Artifacts, Storage
 from mlem.core.data_type import (
     DataHook,
@@ -17,6 +19,11 @@
 from mlem.core.requirements import InstallableRequirement, Requirements
 
 
+def python_type_from_torch_string_repr(dtype: str):
+    # not sure this will work all the time
+    return python_type_from_np_string_repr(dtype)
+
+
 class TorchTensorDataType(
     DataType, DataSerializer, DataHook, IsInstanceHookMixin
 ):
@@ -68,14 +75,26 @@ def get_writer(
     ) -> DataWriter:
         return TorchTensorWriter(**kwargs)
 
+    def _subtype(self, subshape: Tuple[Optional[int], ...]):
+        if len(subshape) == 0:
+            return python_type_from_torch_string_repr(self.dtype)
+        return conlist(
+            self._subtype(subshape[1:]),
+            min_items=subshape[0],
+            max_items=subshape[0],
+        )
+
     def get_model(self, prefix: str = ""):
-        raise NotImplementedError
+        return create_model(
+            prefix + "TorchTensor",
+            __root__=(List[self._subtype(self.shape[1:])], ...),  # type: ignore
+        )
 
     @classmethod
     def process(cls, obj: torch.Tensor, **kwargs) -> DataType:
         return TorchTensorDataType(
             shape=(None,) + obj.shape[1:],
-            dtype=str(obj.dtype)[len(obj.dtype.__module__) + 1 :],
+            dtype=str(obj.dtype)[len("torch") + 1 :],
         )
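
Usage note (not part of this commit): a minimal sketch of how the changed code is expected to behave, assuming TorchTensorDataType.process can be called directly and that python_type_from_np_string_repr maps "float32" to Python float. The "My" prefix is a hypothetical example value.

```python
# Sketch of expected behavior after this change, not code from the commit.
import torch
from pydantic import ValidationError

from mlem.contrib.torch import TorchTensorDataType

tensor = torch.zeros(3, 4, dtype=torch.float32)

# dtype inspection fix: str(torch.float32) == "torch.float32",
# so slicing off len("torch") + 1 characters leaves "float32"
dt = TorchTensorDataType.process(tensor)
assert dt.dtype == "float32"
assert dt.shape == (None, 4)  # first dimension left unbound

# get_model no longer raises NotImplementedError; it builds a pydantic model
# whose __root__ is a list of fixed-length rows (nested conlist per dimension)
Model = dt.get_model("My")
Model(__root__=[[0.0, 1.0, 2.0, 3.0]])  # a row of length 4 validates

try:
    Model(__root__=[[0.0, 1.0]])  # wrong row length is rejected
except ValidationError as err:
    print("rejected:", err.errors()[0]["msg"])
```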
