diff --git a/modelkit/core/models/tensorflow_model.py b/modelkit/core/models/tensorflow_model.py index 9237d5ec..6d786c6c 100644 --- a/modelkit/core/models/tensorflow_model.py +++ b/modelkit/core/models/tensorflow_model.py @@ -57,6 +57,7 @@ def __init__(self, *args, **kwargs): raise ImportError( "Tensorflow is not installed, instal modelkit[tensorflow]." ) + super().__init__(*args, **kwargs) def _is_empty(self, item) -> bool: return False @@ -87,7 +88,7 @@ def _rebuild_predictions_with_mask( return results -class TensorflowModel(Model[ItemType, ReturnType], TensorflowModelMixin): +class TensorflowModel(TensorflowModelMixin, Model[ItemType, ReturnType]): def __init__( self, output_tensor_mapping: Optional[Dict[str, str]] = None, @@ -278,7 +279,7 @@ def close(self): return self.requests_session.close() -class AsyncTensorflowModel(AsyncModel[ItemType, ReturnType], TensorflowModelMixin): +class AsyncTensorflowModel(TensorflowModelMixin, AsyncModel[ItemType, ReturnType]): def __init__( self, output_tensor_mapping: Optional[Dict[str, str]] = None, diff --git a/tests/test_tf_model.py b/tests/test_tf_model.py index 3ed4905e..9457375a 100644 --- a/tests/test_tf_model.py +++ b/tests/test_tf_model.py @@ -4,24 +4,17 @@ import requests from modelkit import ModelLibrary, testing +from modelkit.core.models.tensorflow_model import ( + AsyncTensorflowModel, + TensorflowModel, + has_tensorflow, +) from modelkit.core.settings import LibrarySettings +from modelkit.testing import tf_serving_fixture +from modelkit.utils.tensorflow import write_config from tests import TEST_DIR from tests.conftest import skip_unless -try: - from modelkit.core.models.tensorflow_model import ( - AsyncTensorflowModel, - TensorflowModel, - ) - from modelkit.testing import tf_serving_fixture - from modelkit.utils.tensorflow import write_config -except NameError: - # This occurs because type annotations in - # modelkit.core.models.tensorflow_model will raise - # `NameError: name 'prediction_service_pb2_grpc' is not defined` - # when tensorflow-serving-api is not installed - pass - np = pytest.importorskip("numpy") grpc = pytest.importorskip("grpc") @@ -434,3 +427,15 @@ def test_iso_serving_mode_no_serving(dummy_tf_models, monkeypatch, working_dir): ), models=dummy_tf_models, ) + + +@pytest.mark.skipif( + has_tensorflow, + reason="This test needs not Tensorflow to be run", +) +@pytest.mark.parametrize("ModelType", [AsyncTensorflowModel, TensorflowModel]) +def test_tf_model_with_mixin_mro(ModelType): + # TensorflowModelMixin raises an ImportError when inherited the right way + # if tensorflow is not installed + with pytest.raises(ImportError): + ModelType()