Tensorflow 2.16 / Keras 3 support #2329
I've discovered that explicitly passing an input signature makes the function work again:

```python
import keras
import tensorflow as tf
import tf2onnx

inputs = [
    keras.Input((128, 2), batch_size=64, dtype="float32", name="input_1"),
    keras.Input((64, 4), batch_size=64, dtype="float32", name="input_2"),
]
# [...] Rest of Keras code
model = keras.Model(inputs=inputs, outputs=outputs)

model_proto, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=(
        tf.TensorSpec(
            model.inputs[0].shape,
            dtype=model.inputs[0].dtype,
            name=model.inputs[0].name,
        ),
        tf.TensorSpec(
            model.inputs[1].shape,
            dtype=model.inputs[1].dtype,
            name=model.inputs[1].name,
        ),
    ),
    output_path="model.onnx",
)
```

However, I do not know how robust this workaround is.

Use `python -m tf2onnx.convert --saved-model "path/to/saved_model_folder" --output "path/to/model.onnx"` to convert it to the ONNX format.
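The two hand-written `tf.TensorSpec` entries above generalize to any number of inputs. Below is a minimal sketch of a hypothetical helper (`build_input_signature` is my own name, not part of tf2onnx); it only assumes each entry of `model.inputs` exposes `shape`, `dtype`, and `name`, and the spec class is injected so the helper itself carries no TensorFlow dependency:

```python
def build_input_signature(model, spec_cls):
    """Build one spec per model input.

    spec_cls would be tf.TensorSpec in real use; it is passed in
    explicitly so this helper can be exercised without TensorFlow.
    """
    return tuple(
        spec_cls(i.shape, dtype=i.dtype, name=i.name) for i in model.inputs
    )
```

With TensorFlow available, `input_signature=build_input_signature(model, tf.TensorSpec)` would replace the explicit two-element tuple shown above.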
Is there any news on getting tf2onnx working with Keras models on recent versions of Keras/TensorFlow? I've tried the workarounds suggested above: `tf2onnx.convert.from_keras()` gives me the error "Cannot convert a symbolic tf.Tensor (input_1:0) to a numpy array.", and `model.export()` gives "AttributeError: module 'keras._tf_keras.keras.backend' has no attribute 'set_learning_phase'." Python 3.11 / TensorFlow 2.16.1 / Windows 10
This snippet works (at least for my use case):

```python
import keras
import tensorflow as tf
import tf2onnx


def _convert_to_onnx(source_path: str, destination_path: str):
    model = keras.models.load_model(source_path)
    # Grab the input tensor from the first layer (_input_tensor is a
    # private Keras attribute) to build the signature.
    input_tensor = model.layers[0]._input_tensor
    input_signature = tf.TensorSpec(
        name=input_tensor.name, shape=input_tensor.shape, dtype=input_tensor.dtype
    )
    output_name = model.layers[-1].name

    # Wrap the model in a tf.function that returns a named dict, so the
    # ONNX output gets a stable name.
    @tf.function(input_signature=[input_signature])
    def _wrapped_model(input_data):
        return {output_name: model(input_data)}

    tf2onnx.convert.from_function(
        _wrapped_model, input_signature=[input_signature], output_path=destination_path
    )
```
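The dict-return trick in the snippet above also extends to models with several outputs. Here is a sketch of a hypothetical helper (`wrap_with_named_outputs` is my own name, not a tf2onnx API); in real use you would still decorate the returned callable with `tf.function(input_signature=...)` before handing it to `tf2onnx.convert.from_function`:

```python
def wrap_with_named_outputs(model_fn, output_names):
    """Wrap a callable so its outputs come back as a {name: tensor} dict,
    which fixes the output names when converting via from_function."""
    def wrapped(*args):
        outputs = model_fn(*args)
        # Normalize single-output models to a one-element tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs,)
        if len(outputs) != len(output_names):
            raise ValueError(
                "expected %d outputs, got %d" % (len(output_names), len(outputs))
            )
        return dict(zip(output_names, outputs))
    return wrapped
```

The helper itself is plain Python, so it works the same whether `model_fn` returns eager tensors, symbolic tensors, or ordinary values.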
Does it work for anyone when using JAX as the backend?
The switch from Keras 2 to Keras 3 in TensorFlow 2.16 apparently breaks tf2onnx. This is probably the same issue people are seeing with `tf.lite.TFLiteConverter` since Keras 3: keras-team/keras#18430

Is there an alternative route like `tf2onnx.convert.from_function` we could use as a workaround?