diff --git a/keras_core/utils/summary_utils.py b/keras_core/utils/summary_utils.py index 8dc25e33d..e7e5ddf94 100644 --- a/keras_core/utils/summary_utils.py +++ b/keras_core/utils/summary_utils.py @@ -308,7 +308,7 @@ def print_layer(layer, nested_level=0): non_trainable_count = count_params(model.non_trainable_weights) non_trainable_memory_size = weight_memory_size(model.non_trainable_weights) - if model.compiled and model.optimizer.built: + if model.compiled and model.optimizer and model.optimizer.built: optimizer_weight_count = count_params(model.optimizer.variables) optimizer_memory_size = weight_memory_size(model.optimizer.variables) else: diff --git a/keras_core/utils/summary_utils_test.py b/keras_core/utils/summary_utils_test.py index 01798445f..83d98e0cf 100644 --- a/keras_core/utils/summary_utils_test.py +++ b/keras_core/utils/summary_utils_test.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from absl.testing import parameterized from keras_core import layers from keras_core import models @@ -9,34 +10,34 @@ from keras_core.utils import summary_utils -class SummaryUtilsTest(testing.TestCase): +class SummaryUtilsTest(testing.TestCase, parameterized.TestCase): @pytest.mark.requires_trainable_backend - def test_print_model_summary(self): + @parameterized.parameters([("adam",), (None,)]) + def test_print_model_summary(self, optimizer): inputs = layers.Input((2,)) outputs = layers.Dense(3)(inputs) model = models.Model(inputs, outputs) - model.compile(optimizer="adam", loss="mse", metrics=["mse"]) - # Trigger the optimizer weights creation - model.fit(x=np.zeros([4, 2]), y=np.zeros([4, 3])) + model.compile(optimizer=optimizer, loss="mse", metrics=["mse"]) + if optimizer: + # Trigger the optimizer weights creation + model.fit(x=np.zeros([4, 2]), y=np.zeros([4, 3])) - file_name = "model_1.txt" - temp_dir = self.get_temp_dir() - fpath = os.path.join(temp_dir, file_name) - writer = open(fpath, "w") + summary_content = [] - def print_to_file(text, line_break=False): - print(text, file=writer) + def print_to_variable(text, line_break=False): + summary_content.append(text) try: - summary_utils.print_summary(model, print_fn=print_to_file) - writer.close() - self.assertTrue(os.path.exists(fpath)) - with open(fpath, "r") as reader: - summary_content = reader.read() - # self.assertEqual(len(lines), 15) - self.assertIn("Total params: 29", summary_content) - self.assertIn("Trainable params: 9", summary_content) - self.assertIn("Non-trainable params: 0", summary_content) - self.assertIn("Optimizer params: 20", summary_content) + summary_utils.print_summary(model, print_fn=print_to_variable) + summary_content = "\n".join(summary_content) + if optimizer: + self.assertIn("Total params: 29", summary_content) + self.assertIn("Trainable params: 9", summary_content) + self.assertIn("Non-trainable params: 0", summary_content) + self.assertIn("Optimizer params: 20", summary_content) + else: + self.assertIn("Total params: 9", summary_content) + self.assertIn("Trainable params: 9", summary_content) + self.assertIn("Non-trainable params: 0", summary_content) except ImportError: pass