Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Fix torch serving and saving (again) #296

Merged
merged 1 commit into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mlem/contrib/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def _subtype(self, subshape: Tuple[Optional[int], ...]):
)

def get_model(self, prefix: str = "") -> Type[BaseModel]:
    """Build a pydantic model whose ``__root__`` is a nested-list view of the array.

    The outermost dimension is an unconstrained ``List``; the inner
    dimensions are produced by ``self._subtype`` from ``self.shape[1:]``.
    ``prefix`` is prepended to the generated model's name.
    """
    # TODO: https://github.com/iterative/mlem/issues/33
    return create_model(
        prefix + "NumpyNdarray", __root__=(List[self._subtype(self.shape[1:])], ...) # type: ignore
    )
Expand Down
25 changes: 22 additions & 3 deletions mlem/contrib/torch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, ClassVar, Iterator, Optional, Tuple
from typing import Any, ClassVar, Iterator, List, Optional, Tuple

import torch
from pydantic import conlist, create_model

from mlem.constants import PREDICT_METHOD_NAME
from mlem.contrib.numpy import python_type_from_np_string_repr
from mlem.core.artifacts import Artifacts, Storage
from mlem.core.data_type import (
DataHook,
Expand All @@ -17,6 +19,11 @@
from mlem.core.requirements import InstallableRequirement, Requirements


def python_type_from_torch_string_repr(dtype: str):
    """Return the builtin python type for a torch dtype name (e.g. ``"float32"``).

    Torch dtype names mirror numpy's string representations, so this
    delegates to the numpy helper.
    """
    # NOTE(review): delegation assumes every torch dtype string is a valid
    # numpy dtype string — original author flagged this as unverified.
    py_type = python_type_from_np_string_repr(dtype)
    return py_type


class TorchTensorDataType(
DataType, DataSerializer, DataHook, IsInstanceHookMixin
):
Expand Down Expand Up @@ -68,14 +75,26 @@ def get_writer(
) -> DataWriter:
return TorchTensorWriter(**kwargs)

def _subtype(self, subshape: Tuple[Optional[int], ...]):
    """Recursively build the pydantic field type for the trailing dimensions.

    An empty ``subshape`` bottoms out at the scalar python type for
    ``self.dtype``; otherwise each level is a ``conlist`` pinned to the
    leading dimension's size.
    """
    if subshape:
        size = subshape[0]
        inner = self._subtype(subshape[1:])
        return conlist(inner, min_items=size, max_items=size)
    return python_type_from_torch_string_repr(self.dtype)

def get_model(self, prefix: str = ""):
    """Build a pydantic model whose ``__root__`` is a nested-list view of the tensor.

    Mirrors the numpy implementation: outer dimension is an unconstrained
    ``List``; inner dimensions come from ``self._subtype(self.shape[1:])``.
    ``prefix`` is prepended to the generated model's name.
    """
    # The pre-fix `raise NotImplementedError` left interleaved by the diff
    # made the return unreachable; only the implemented body is kept.
    return create_model(
        prefix + "TorchTensor",
        __root__=(List[self._subtype(self.shape[1:])], ...),  # type: ignore
    )

@classmethod
def process(cls, obj: torch.Tensor, **kwargs) -> DataType:
    """Build a TorchTensorDataType describing a sample tensor.

    The batch (first) dimension is recorded as ``None`` (variable size);
    the remaining dimensions are taken from the sample. The dtype string
    strips the ``"torch."`` prefix, e.g. ``torch.float32`` -> ``"float32"``.
    """
    # str(obj.dtype) is always "torch.<name>", so slice past "torch." —
    # the previous len(obj.dtype.__module__) form duplicated the keyword
    # (diff residue) and torch dtype instances expose no __module__ here.
    return TorchTensorDataType(
        shape=(None,) + obj.shape[1:],
        dtype=str(obj.dtype)[len("torch") + 1 :],
    )


Expand Down