Skip to content

Commit

Permalink
Fix summary when optimizer is None
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Sep 13, 2023
1 parent beba6f1 commit 91d489f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion keras_core/utils/summary_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def print_layer(layer, nested_level=0):
non_trainable_count = count_params(model.non_trainable_weights)
non_trainable_memory_size = weight_memory_size(model.non_trainable_weights)

if model.compiled and model.optimizer.built:
if model.compiled and model.optimizer and model.optimizer.built:
optimizer_weight_count = count_params(model.optimizer.variables)
optimizer_memory_size = weight_memory_size(model.optimizer.variables)
else:
Expand Down
43 changes: 22 additions & 21 deletions keras_core/utils/summary_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,42 @@

import numpy as np
import pytest
from absl.testing import parameterized

from keras_core import layers
from keras_core import models
from keras_core import testing
from keras_core.utils import summary_utils


class SummaryUtilsTest(testing.TestCase):
class SummaryUtilsTest(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_print_model_summary(self):
@parameterized.parameters([("adam",), (None,)])
def test_print_model_summary(self, optimizer):
inputs = layers.Input((2,))
outputs = layers.Dense(3)(inputs)
model = models.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mse"])
# Trigger the optimizer weights creation
model.fit(x=np.zeros([4, 2]), y=np.zeros([4, 3]))
model.compile(optimizer=optimizer, loss="mse", metrics=["mse"])
if optimizer:
# Trigger the optimizer weights creation
model.fit(x=np.zeros([4, 2]), y=np.zeros([4, 3]))

file_name = "model_1.txt"
temp_dir = self.get_temp_dir()
fpath = os.path.join(temp_dir, file_name)
writer = open(fpath, "w")
summary_content = []

def print_to_file(text, line_break=False):
print(text, file=writer)
def print_to_variable(text, line_break=False):
summary_content.append(text)

try:
summary_utils.print_summary(model, print_fn=print_to_file)
writer.close()
self.assertTrue(os.path.exists(fpath))
with open(fpath, "r") as reader:
summary_content = reader.read()
# self.assertEqual(len(lines), 15)
self.assertIn("Total params: 29", summary_content)
self.assertIn("Trainable params: 9", summary_content)
self.assertIn("Non-trainable params: 0", summary_content)
self.assertIn("Optimizer params: 20", summary_content)
summary_utils.print_summary(model, print_fn=print_to_variable)
summary_content = "\n".join(summary_content)
if optimizer:
self.assertIn("Total params: 29", summary_content)
self.assertIn("Trainable params: 9", summary_content)
self.assertIn("Non-trainable params: 0", summary_content)
self.assertIn("Optimizer params: 20", summary_content)
else:
self.assertIn("Total params: 9", summary_content)
self.assertIn("Trainable params: 9", summary_content)
self.assertIn("Non-trainable params: 0", summary_content)
except ImportError:
pass

0 comments on commit 91d489f

Please sign in to comment.