Skip to content

Commit

Permalink
Add task arg to torchmetrics metrics in relevant examples
Browse files Browse the repository at this point in the history
  • Loading branch information
freud14 committed Dec 2, 2022
1 parent 1130501 commit bab2195
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/source/examples/semantic_segmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ For training, we use the Jaccard index metric in addition to the accuracy and F1
optimizer,
criterion,
batch_metrics=['accuracy'],
epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22)],
epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22, task="multiclass")],
device=device,
)
Expand Down
18 changes: 9 additions & 9 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,22 @@ Examples:
metric_collection = MetricCollection(
[
F1Score(num_classes=10, average='macro'),
Precision(num_classes=10, average='macro'),
Recall(num_classes=10, average='macro'),
F1Score(num_classes=10, average="macro", task="multiclass"),
Precision(num_classes=10, average="macro", task="multiclass"),
Recall(num_classes=10, average="macro", task="multiclass"),
]
)
metrics = [
('custom_name', my_custom_metric),
(('metric_1', 'metric_2'), my_custom_metric2),
(('a', 'b'), my_custom_metric3),
(('metric_3', 'metric_4'), CustomMetric()),
(('c', 'd'), CustomMetric2()),
("custom_name", my_custom_metric),
(("metric_1", "metric_2"), my_custom_metric2),
(("a", "b"), my_custom_metric3),
(("metric_3", "metric_4"), CustomMetric()),
(("c", "d"), CustomMetric2()),
# No need to pass the names since the class sets the attribute __name__.
CustomMetric3(),
# The names are the keys returned by MetricCollection.
(('F1Score', 'Precision', 'Recall'), metric_collection),
(("F1Score", "Precision", "Recall"), metric_collection),
]
2 changes: 1 addition & 1 deletion examples/semantic_segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@
" optimizer,\n",
" criterion,\n",
" batch_metrics=['accuracy'],\n",
" epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22)],\n",
" epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22, task=\"multiclass\")],\n",
" device=device,\n",
")\n",
"\n",
Expand Down

0 comments on commit bab2195

Please sign in to comment.