Skip to content

Commit

Permalink
Merge pull request #201 from antoinejeannot/fix-MRO
Browse files Browse the repository at this point in the history
tensorflow model: fix TensorflowModelMixin usage
  • Loading branch information
antoinejeannot authored Nov 13, 2023
2 parents 0c48001 + 8128fee commit 139af52
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
5 changes: 3 additions & 2 deletions modelkit/core/models/tensorflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, *args, **kwargs):
raise ImportError(
"Tensorflow is not installed, instal modelkit[tensorflow]."
)
super().__init__(*args, **kwargs)

def _is_empty(self, item) -> bool:
return False
Expand Down Expand Up @@ -87,7 +88,7 @@ def _rebuild_predictions_with_mask(
return results


class TensorflowModel(Model[ItemType, ReturnType], TensorflowModelMixin):
class TensorflowModel(TensorflowModelMixin, Model[ItemType, ReturnType]):
def __init__(
self,
output_tensor_mapping: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -278,7 +279,7 @@ def close(self):
return self.requests_session.close()


class AsyncTensorflowModel(AsyncModel[ItemType, ReturnType], TensorflowModelMixin):
class AsyncTensorflowModel(TensorflowModelMixin, AsyncModel[ItemType, ReturnType]):
def __init__(
self,
output_tensor_mapping: Optional[Dict[str, str]] = None,
Expand Down
33 changes: 19 additions & 14 deletions tests/test_tf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,17 @@
import requests

from modelkit import ModelLibrary, testing
from modelkit.core.models.tensorflow_model import (
AsyncTensorflowModel,
TensorflowModel,
has_tensorflow,
)
from modelkit.core.settings import LibrarySettings
from modelkit.testing import tf_serving_fixture
from modelkit.utils.tensorflow import write_config
from tests import TEST_DIR
from tests.conftest import skip_unless

try:
from modelkit.core.models.tensorflow_model import (
AsyncTensorflowModel,
TensorflowModel,
)
from modelkit.testing import tf_serving_fixture
from modelkit.utils.tensorflow import write_config
except NameError:
# This occurs because type annotations in
# modelkit.core.models.tensorflow_model will raise
# `NameError: name 'prediction_service_pb2_grpc' is not defined`
# when tensorflow-serving-api is not installed
pass

np = pytest.importorskip("numpy")
grpc = pytest.importorskip("grpc")

Expand Down Expand Up @@ -434,3 +427,15 @@ def test_iso_serving_mode_no_serving(dummy_tf_models, monkeypatch, working_dir):
),
models=dummy_tf_models,
)


@pytest.mark.skipif(
has_tensorflow,
reason="This test needs not Tensorflow to be run",
)
@pytest.mark.parametrize("ModelType", [AsyncTensorflowModel, TensorflowModel])
def test_tf_model_with_mixin_mro(ModelType):
# TensorflowModelMixin raises an ImportError when inherited the right way
# if tensorflow is not installed
with pytest.raises(ImportError):
ModelType()

0 comments on commit 139af52

Please sign in to comment.