Skip to content

Commit

Permalink
Fix KerasFileEditor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Oct 3, 2024
1 parent 726a38f commit 6ec0f46
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions keras/src/saving/file_editor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,55 +37,55 @@ def test_basics(self):

target_model = get_target_model()

out = editor.compare_to(model) # Succeeds
out = editor.compare(model) # Succeeds
self.assertEqual(out["status"], "success")
out = editor.compare_to(target_model) # Fails
out = editor.compare(target_model) # Fails

editor.add_object(
"layers/dense_3", weights={"0": np.random.random((3, 3))}
)
out = editor.compare_to(target_model) # Fails
out = editor.compare(target_model) # Fails
self.assertEqual(out["status"], "error")
self.assertEqual(out["error_count"], 2)

editor.rename_object("dense_3", "dense_4")
editor.rename_object("layers/dense_4", "dense_2")
editor.add_weights("dense_2", weights={"1": np.random.random((3,))})
out = editor.compare_to(target_model) # Succeeds
out = editor.compare(target_model) # Succeeds
self.assertEqual(out["status"], "success")

editor.add_object(
"layers/dense_3", weights={"0": np.random.random((3, 3))}
)
out = editor.compare_to(target_model) # Fails
out = editor.compare(target_model) # Fails
self.assertEqual(out["status"], "error")
self.assertEqual(out["error_count"], 1)

editor.delete_object("layers/dense_3")
out = editor.compare_to(target_model) # Succeeds
out = editor.compare(target_model) # Succeeds
self.assertEqual(out["status"], "success")
editor.summary()

temp_filepath = os.path.join(self.get_temp_dir(), "resaved.weights.h5")
editor.resave_weights(temp_filepath)
editor.save(temp_filepath)
target_model.load_weights(temp_filepath)

editor = KerasFileEditor(temp_filepath)
editor.summary()
out = editor.compare_to(target_model) # Succeeds
out = editor.compare(target_model) # Succeeds
self.assertEqual(out["status"], "success")

editor.delete_weight("dense_2", "1")
out = editor.compare_to(target_model) # Fails
out = editor.compare(target_model) # Fails
self.assertEqual(out["status"], "error")
self.assertEqual(out["error_count"], 1)

editor.add_weights("dense_2", {"1": np.zeros((7,))})
out = editor.compare_to(target_model) # Fails
out = editor.compare(target_model) # Fails
self.assertEqual(out["status"], "error")
self.assertEqual(out["error_count"], 1)

editor.delete_weight("dense_2", "1")
editor.add_weights("dense_2", {"1": np.zeros((3,))})
out = editor.compare_to(target_model) # Succeeds
out = editor.compare(target_model) # Succeeds
self.assertEqual(out["status"], "success")

0 comments on commit 6ec0f46

Please sign in to comment.