diff --git a/neural_compressor/model/tensorflow_model.py b/neural_compressor/model/tensorflow_model.py index e4809863a55..0caadc53dbd 100644 --- a/neural_compressor/model/tensorflow_model.py +++ b/neural_compressor/model/tensorflow_model.py @@ -81,7 +81,9 @@ def get_model_type(model): return "graph" elif isinstance(model, tf.compat.v1.GraphDef): return "graph_def" - elif isinstance(model, tf.compat.v1.estimator.Estimator): + elif not version1_gte_version2(tf.version.VERSION, "2.16.1") and isinstance( + model, tf.compat.v1.estimator.Estimator + ): return "estimator" elif isinstance(model, str): model = os.path.abspath(os.path.expanduser(model))