From e45b77138a98380638f75aa3397c7150959fca5b Mon Sep 17 00:00:00 2001 From: Antoine Jeannot Date: Wed, 11 Oct 2023 17:13:12 +0200 Subject: [PATCH 1/2] tensorflow model: fix TensorflowModelMixin usage --- modelkit/core/models/tensorflow_model.py | 7 +++-- tests/test_tf_model.py | 33 ++++++++++++++---------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/modelkit/core/models/tensorflow_model.py b/modelkit/core/models/tensorflow_model.py index 9237d5ec..4f6b4020 100644 --- a/modelkit/core/models/tensorflow_model.py +++ b/modelkit/core/models/tensorflow_model.py @@ -57,6 +57,9 @@ def __init__(self, *args, **kwargs): raise ImportError( "Tensorflow is not installed, instal modelkit[tensorflow]." ) + # this mixing implementing __init__ + # we need to call super().__init__ to call the Model.__init__ + super().__init__(*args, **kwargs) def _is_empty(self, item) -> bool: return False @@ -87,7 +90,7 @@ def _rebuild_predictions_with_mask( return results -class TensorflowModel(Model[ItemType, ReturnType], TensorflowModelMixin): +class TensorflowModel(TensorflowModelMixin, Model[ItemType, ReturnType]): def __init__( self, output_tensor_mapping: Optional[Dict[str, str]] = None, @@ -278,7 +281,7 @@ def close(self): return self.requests_session.close() -class AsyncTensorflowModel(AsyncModel[ItemType, ReturnType], TensorflowModelMixin): +class AsyncTensorflowModel(TensorflowModelMixin, AsyncModel[ItemType, ReturnType]): def __init__( self, output_tensor_mapping: Optional[Dict[str, str]] = None, diff --git a/tests/test_tf_model.py b/tests/test_tf_model.py index 3ed4905e..9457375a 100644 --- a/tests/test_tf_model.py +++ b/tests/test_tf_model.py @@ -4,24 +4,17 @@ import requests from modelkit import ModelLibrary, testing +from modelkit.core.models.tensorflow_model import ( + AsyncTensorflowModel, + TensorflowModel, + has_tensorflow, +) from modelkit.core.settings import LibrarySettings +from modelkit.testing import tf_serving_fixture +from modelkit.utils.tensorflow import write_config from tests import TEST_DIR from tests.conftest import skip_unless -try: - from modelkit.core.models.tensorflow_model import ( - AsyncTensorflowModel, - TensorflowModel, - ) - from modelkit.testing import tf_serving_fixture - from modelkit.utils.tensorflow import write_config -except NameError: - # This occurs because type annotations in - # modelkit.core.models.tensorflow_model will raise - # `NameError: name 'prediction_service_pb2_grpc' is not defined` - # when tensorflow-serving-api is not installed - pass - np = pytest.importorskip("numpy") grpc = pytest.importorskip("grpc") @@ -434,3 +427,15 @@ def test_iso_serving_mode_no_serving(dummy_tf_models, monkeypatch, working_dir): ), models=dummy_tf_models, ) + + +@pytest.mark.skipif( + has_tensorflow, + reason="This test needs not Tensorflow to be run", +) +@pytest.mark.parametrize("ModelType", [AsyncTensorflowModel, TensorflowModel]) +def test_tf_model_with_mixin_mro(ModelType): + # TensorflowModelMixin raises an ImportError when inherited the right way + # if tensorflow is not installed + with pytest.raises(ImportError): + ModelType() From 8128fee8110a843a3059b86a2482beac4865f950 Mon Sep 17 00:00:00 2001 From: Antoine Jeannot Date: Thu, 12 Oct 2023 13:06:12 +0200 Subject: [PATCH 2/2] tensorflow model: remove unhelpful comment --- modelkit/core/models/tensorflow_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modelkit/core/models/tensorflow_model.py b/modelkit/core/models/tensorflow_model.py index 4f6b4020..6d786c6c 100644 --- a/modelkit/core/models/tensorflow_model.py +++ b/modelkit/core/models/tensorflow_model.py @@ -57,8 +57,6 @@ def __init__(self, *args, **kwargs): raise ImportError( "Tensorflow is not installed, instal modelkit[tensorflow]." ) - # this mixing implementing __init__ - # we need to call super().__init__ to call the Model.__init__ super().__init__(*args, **kwargs) def _is_empty(self, item) -> bool: