Skip to content

Commit

Permalink
Update log_sklearn_plot. (#4074)
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Oct 28, 2022
1 parent db49ca6 commit f1d9c10
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# Live.log_plot()
# Live.log_sklearn_plot()

Generates a
[scikit learn plot](https://scikit-learn.org/stable/visualizations.html) and
saves the data in `{Live.dir}/plots/{name}.json`.
saves the data in `{Live.dir}/plots/sklearn/{name}.json`.

```py
def log_plot(self, name: str, labels, predictions, **kwargs):
def log_sklearn_plot(
self,
kind: Literal['calibration', 'confusion_matrix', 'precision_recall', 'roc'],
labels,
predictions,
name: Optional[str] = None,
**kwargs):
```

## Usage
Expand All @@ -18,40 +24,20 @@ live = Live()
y_true = [0, 0, 1, 1]
y_pred = [1, 0, 1, 0]
y_score = [0.1, 0.4, 0.35, 0.8]]
live.log_plot("roc", y_true, y_score)
live.log_plot("confusion_matrix", y_true, y_pred)
live.log_sklearn_plot("roc", y_true, y_score)
live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="cm.json")
```

## Description

Uses `name` to determine which plot should be generated. See
Uses `kind` to determine the type of plot to be generated. See
[supported plots](#supported-plots).

<admon type="tip">

The generated `{Live.dir}/plots/{name}.json` can be visualized with `dvc plots`.

</admon>

### Step updates

`Live.log_plot()` can be currently only used when `step` is `None`.

If you perform `step` updates in your code, you can later use
`Live.set_step(None)` in order to be able to use `Live.log_plot()`.

```python
for epoch in range(NUM_EPOCHS):
live.log_metric(metric_name, value)
live.next_step()

live.set_step(None)
live.log_plot("roc", y_true, y_score)
```
If `name` is not provided, `kind` will be used as the default name.

## Supported plots

`name` must be one of the supported plots:
`kind` must be one of the supported plots:

<toggle>

Expand All @@ -63,19 +49,19 @@ plot.

Calls
[sklearn.calibration.calibration_curve](https://scikit-learn.org/stable/modules/generated/sklearn.calibration.calibration_curve.html)
and stores the data at `{Live.dir}/plots/calibratrion.json` in a format
and stores the data at `{Live.dir}/plots/sklearn/calibratrion.json` in a format
compatible with `dvc plots`.

```py
y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_plot("calibration", y_true, y_score)
live.log_sklearn_plot("calibration", y_true, y_score)
```

Example usage with `dvc plots`:

```dvc
$ dvc plots show 'dvclive/plots/calibration.json' \
$ dvc plots show 'dvclive/plots/sklearn/calibration.json' \
-x prob_pred -y prob_true \
--x-label 'Mean Predicted Probability' \
--y-label 'Fraction of Positives' \
Expand All @@ -91,21 +77,22 @@ $ dvc plots show 'dvclive/plots/calibration.json' \
Generates a [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix)
plot.

Stores the labels and predictions in `{Live.dir}/plots/confusion_matrix.json`,
with the format expected by the confusion matrix
Stores the labels and predictions in
`{Live.dir}/plots/sklearn/confusion_matrix.json`, with the format expected by
the confusion matrix
[template](/doc/user-guide/visualizing-plots#plot-templates-data-series-only) of
`dvc plots`.

```py
y_true = [1, 1, 2, 2]
y_pred = [2, 1, 1, 2]
live.log_plot("confusion_matrix", y_true, y_pred)
live.log_sklearn_plot("confusion_matrix", y_true, y_pred)
```

Example usage with `dvc plots`:

```dvc
$ dvc plots show 'dvclive/plots/confusion_matrix.json' \
$ dvc plots show 'dvclive/plots/sklearn/confusion_matrix.json' \
-x actual -y predicted \
--template confusion
```
Expand All @@ -122,19 +109,19 @@ plot.

Calls
[sklearn.metrics.det_curve](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.det_curve.html)
and stores the data at `{Live.dir}/plots/det.json` in a format compatible with
`dvc plots`.
and stores the data at `{Live.dir}/plots/sklearn/det.json` in a format
compatible with `dvc plots`.

```py
y_true = [1, 1, 2, 2]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_plot("det", y_true, y_score)
live.log_sklearn_plot("det", y_true, y_score)
```

Example usage with `dvc plots`:

```dvc
$ dvc plots show 'dvclive/plots/det.json' \
$ dvc plots show 'dvclive/plots/sklearn/det.json' \
-x fpr -y fnr \
--title 'DET Curve'
```
Expand All @@ -151,19 +138,19 @@ plot.

Calls
[sklearn.metrics.precision_recall_curve](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html)
and stores the data at `{Live.dir}/plots/precision_recall.json` in a format
compatible with `dvc plots`.
and stores the data at `{Live.dir}/plots/sklearn/precision_recall.json` in a
format compatible with `dvc plots`.

```py
y_true = [1, 1, 2, 2]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_plot("precision_recall", y_true, y_score)
live.log_sklearn_plot("precision_recall", y_true, y_score)
```

Example usage with `dvc plots`:

```dvc
$ dvc plots show 'dvclive/plots/precision_recall.json' \
$ dvc plots show 'dvclive/plots/sklearn/precision_recall.json' \
-x recall -y precision \
--title 'Precision Recall Curve'
```
Expand All @@ -180,19 +167,19 @@ plot.

Calls
[sklearn.metrics.roc_curve](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#sklearn.metrics.roc_curve)
and stores the data at `{Live.dir}/plots/roc.json` in a format compatible with
`dvc plots`.
and stores the data at `{Live.dir}/plots/sklearn/roc.json` in a format
compatible with `dvc plots`.

```py
y_true = [1, 1, 2, 2]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_plot("roc", y_true, y_score)
live.log_sklearn_plot("roc", y_true, y_score)
```

Example usage with `dvc plots`:

```dvc
$ dvc plots show 'dvclive/plots/roc.json' \
$ dvc plots show 'dvclive/plots/sklearn/roc.json' \
-x fpr -y tpr \
--title 'ROC Curve'
```
Expand All @@ -205,19 +192,20 @@ $ dvc plots show 'dvclive/plots/roc.json' \

## Parameters

- `name` - a [supported plot type](#supported-plots)
- `kind` - a [supported plot type](#supported-plots)

- `labels` - array of ground truth labels

- `predictions` - array of predicted labels (for `confusion_matrix`) or
predicted probabilities (for other plots)

- `name` - Optional name of the output file. If not provided, `kind` will be
used as name.

- `**kwargs` - additional arguments to be passed to the internal scikit-learn
function being called

## Exceptions

- `dvclive.error.InvalidPlotTypeError` - thrown if the provided `name` does not
- `dvclive.error.InvalidPlotTypeError` - thrown if the provided `kind` does not
correspond to any of the supported plots.

- `RuntimeError` - thrown if `Live.log_plot()` is used and `step` is not `None`.
4 changes: 2 additions & 2 deletions content/docs/sidebar.json
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,8 @@
"label": "log_params()"
},
{
"slug": "log_plot",
"label": "log_plot()"
"slug": "log_sklearn_plot",
"label": "log_sklearn_plot()"
},
{
"slug": "make_report",
Expand Down

0 comments on commit f1d9c10

Please sign in to comment.