guide: revisit Checkpoints guide(s) #2753

135 changes: 135 additions & 0 deletions content/docs/user-guide/experiment-management/checkpoints.md
# Checkpoints

<!--
## Checkpoints

To track successive steps in a longer or deeper <abbr>experiment</abbr>, you can
register checkpoints from your code. Each `dvc exp run` will resume from the
last checkpoint.

Checkpoints provide a way to train models iteratively, keeping the metrics,
models and other artifacts associated with each epoch.

### Adding checkpoints to the pipeline

There are various ways to add checkpoints to a project. All of them involve
marking a stage <abbr>output</abbr> with `checkpoint: true` in `dvc.yaml`. This
is needed so that the experiment can later resume from the <abbr>cached</abbr>
output(s).

If you are adding a new stage with `dvc stage add`, you can mark its output(s)
with the `--checkpoints` (`-c`) option. DVC will add `checkpoint: true` to the
stage output in `dvc.yaml`.
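
For example, a command along these lines would create such a stage (the stage
name, dependency, and file names here are illustrative, not from the original
text):

```shell
# Create a "train" stage whose model.pt output is a checkpoint
dvc stage add -n train \
              -c model.pt \
              -d train.py \
              python train.py
```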

Otherwise, if you are adding a checkpoint to an existing project, you can edit
`dvc.yaml` and add `checkpoint: true` to the stage output as shown below:

```yaml
stages:
...
train:
...
outs:
- model.pt:
checkpoint: true
...
```

### Adding checkpoints to Python code

DVC is agnostic about the language you use in your project. Checkpoints are
essentially a mechanism to associate the outputs of a pipeline with its
metrics. Reading the model from the previous iteration and writing a new model
file are not handled by DVC. DVC captures the signal produced by the machine
learning experimentation code and stores each successive checkpoint.

> 💡 DVC provides several automated ways to capture checkpoints for popular ML
> libraries in [DVClive](https://dvc.org/doc/dvclive). It may be more productive
> to use checkpoints via DVClive. Here we discuss adding checkpoints to a
> project manually.

If you are writing the project in Python, the easiest way to signal DVC to
capture a checkpoint is to call the `dvc.api.make_checkpoint()` function. It
creates a checkpoint and records all artifacts changed since the previous
checkpoint as another experiment.

The following snippet shows an example that uses a custom Keras callback class.
The callback signals DVC to create a checkpoint at the end of each epoch.

```python
import tensorflow as tf

import dvc.api

class DVCCheckpointsCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        dvc.api.make_checkpoint()

...

history = model.fit(
    ...,
    callbacks=[DVCCheckpointsCallback(), ...],
)
```

A similar approach can be taken in PyTorch when using a loop to train a model:

```python
import torch

import dvc.api

for epoch in range(1, EPOCHS + 1):
    ...
    for x_batch, y_batch in train_loader:
        train(model, x_batch, y_batch)
    torch.save(model.state_dict(), "model.pt")
    # Evaluate and checkpoint.
    evaluate(model, x_test, y_test)
    dvc.api.make_checkpoint()
    ...
```

Even if you're not using these libraries, you can use checkpoints in your
project at each epoch/step by first recording all intermediate artifacts and
metrics, then calling `dvc.api.make_checkpoint()`.

### Adding checkpoints to non-Python code

If you use another language in your project, you can mimic the behavior of
`make_checkpoint()`. In essence, `make_checkpoint()` creates a special file
named `DVC_CHECKPOINT` inside `.dvc/tmp/` to signal DVC, and waits for the file
to be removed.

The following R snippet implements this:

```r
dvc_root <- Sys.getenv("DVC_ROOT")

if (dvc_root != "") {
  # Signal DVC to create a checkpoint, then wait until DVC
  # removes the signal file before resuming.
  signal_file_path <- file.path(dvc_root, ".dvc", "tmp", "DVC_CHECKPOINT")
  file.create(signal_file_path)
  while (file.exists(signal_file_path)) {
    Sys.sleep(0.01)
  }
}
```

The following Julia snippet does the same:

```julia
dvc_root = get(ENV, "DVC_ROOT", "")

if dvc_root != ""
    # Signal DVC to create a checkpoint, then wait until DVC
    # removes the signal file before resuming.
    signal_file_path = joinpath(dvc_root, ".dvc", "tmp", "DVC_CHECKPOINT")
    touch(signal_file_path)
    while isfile(signal_file_path)
        sleep(0.01)
    end
end
```

<details>

### How are checkpoints captured?

Instead of a single commit, checkpoint experiments have multiple commits under
the custom Git reference (in `.git/refs/exps`), similar to a branch.

</details>
-->

ML checkpoints are an important part of deep learning because ML engineers
often need to save model files at certain points during the training process.
