Skip to content

Commit

Permalink
keras: Use log_artifact.
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Mar 16, 2023
1 parent d66edbd commit 9a48a9c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/dvclive/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def on_epoch_end(
self.model.save_weights(self.model_file)
else:
self.model.save(self.model_file)
self.live.log_artifact(self.model_file)
self.live.next_step()

def on_train_end(
Expand Down
9 changes: 6 additions & 3 deletions tests/test_frameworks/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,20 @@ def test_keras_model_file(tmp_dir, xor_model, mocker, save_weights_only, capture
save = mocker.spy(model, "save")
save_weights = mocker.spy(model, "save_weights")

live_callback = DVCLiveCallback(
model_file="model.h5", save_weights_only=save_weights_only
)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")
model.fit(
x,
y,
epochs=1,
batch_size=1,
callbacks=[
DVCLiveCallback(model_file="model.h5", save_weights_only=save_weights_only)
],
callbacks=[live_callback],
)
assert save.call_count != save_weights_only
assert save_weights.call_count == save_weights_only
log_artifact.assert_called_with(live_callback.model_file)


@pytest.mark.parametrize("save_weights_only", (True, False))
Expand Down

0 comments on commit 9a48a9c

Please sign in to comment.