From 33891bf6da333c0bdf33d550738cc78d47982357 Mon Sep 17 00:00:00 2001 From: Grvzard Date: Tue, 25 Jun 2024 21:53:25 +0800 Subject: [PATCH 1/3] Fix `export_lib.make_tensor_spec` --- keras/src/export/export_lib.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/keras/src/export/export_lib.py b/keras/src/export/export_lib.py index 02714c55c0a..be75c06cd3d 100644 --- a/keras/src/export/export_lib.py +++ b/keras/src/export/export_lib.py @@ -654,13 +654,18 @@ def make_tensor_spec(structure): # into plain Python structures because they don't work with jax2tf/JAX. if isinstance(structure, dict): return {k: make_tensor_spec(v) for k, v in structure.items()} - if isinstance(structure, (list, tuple)): + elif isinstance(structure, tuple): if all(isinstance(d, (int, type(None))) for d in structure): return tf.TensorSpec( shape=(None,) + structure[1:], dtype=model.input_dtype ) - result = [make_tensor_spec(v) for v in structure] - return tuple(result) if isinstance(structure, tuple) else result + return tuple(make_tensor_spec(v) for v in structure) + elif isinstance(structure, list): + if all(isinstance(d, (int, type(None))) for d in structure): + return tf.TensorSpec( + shape=[None] + structure[1:], dtype=model.input_dtype + ) + return [make_tensor_spec(v) for v in structure] else: raise ValueError( f"Unsupported type {type(structure)} for {structure}" From e9a4d86e62a9c0b730bfb96a2fbe54ff6bab44a1 Mon Sep 17 00:00:00 2001 From: Grvzard Date: Wed, 26 Jun 2024 02:01:18 +0800 Subject: [PATCH 2/3] Add test --- keras/src/export/export_lib_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/keras/src/export/export_lib_test.py b/keras/src/export/export_lib_test.py index 29504cfb2b1..811723cd9f3 100644 --- a/keras/src/export/export_lib_test.py +++ b/keras/src/export/export_lib_test.py @@ -196,6 +196,20 @@ def call(self, inputs): ) revived_model.serve(bigger_input) + # Test with keras.saving_lib + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model.keras") + saving_lib.save_model(model, temp_filepath) + revived_model = saving_lib.load_model( + temp_filepath, + { + "TupleModel": TupleModel, + "ArrayModel": ArrayModel, + "DictModel": DictModel, + }, + ) + self.assertAllClose(ref_output, revived_model(ref_input)) + export_lib.export_model(revived_model, self.get_temp_dir()) + def test_model_with_multiple_inputs(self): class TwoInputsModel(models.Model): From aca0141167d64bb67b38550c2f83f5a91048c86c Mon Sep 17 00:00:00 2001 From: Grvzard Date: Wed, 26 Jun 2024 02:07:12 +0800 Subject: [PATCH 3/3] chore(format) --- keras/src/export/export_lib_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras/src/export/export_lib_test.py b/keras/src/export/export_lib_test.py index 811723cd9f3..7b4b7d332dc 100644 --- a/keras/src/export/export_lib_test.py +++ b/keras/src/export/export_lib_test.py @@ -197,7 +197,9 @@ def call(self, inputs): revived_model.serve(bigger_input) # Test with keras.saving_lib - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model.keras") + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.keras" + ) saving_lib.save_model(model, temp_filepath) revived_model = saving_lib.load_model( temp_filepath,