diff --git a/dev-requirements.in b/dev-requirements.in index 313ce1d82b0..295b74c7247 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -13,3 +13,4 @@ google-cloud-bigquery google-cloud-bigquery-storage IPython torch +tensorflow diff --git a/flytekit/extras/keras/__init__.py b/flytekit/extras/keras/__init__.py index d0b5964e6b9..ea1adf8a5da 100644 --- a/flytekit/extras/keras/__init__.py +++ b/flytekit/extras/keras/__init__.py @@ -13,7 +13,7 @@ # that have soft dependencies try: # isolate the exception to the keras import - import keras + from tensorflow import keras _keras_installed = True except (ImportError, OSError):