From bab2195c6fc6dba37b242d11f5ba6e3ae6613056 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fr=C3=A9d=C3=A9rik=20Paradis?=
Date: Fri, 2 Dec 2022 17:03:36 -0500
Subject: [PATCH] Add task arg to torchmetrics metrics in relevant examples

---
 docs/source/examples/semantic_segmentation.rst |  2 +-
 docs/source/metrics.rst                        | 18 +++++++++---------
 examples/semantic_segmentation.ipynb           |  2 +-
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/docs/source/examples/semantic_segmentation.rst b/docs/source/examples/semantic_segmentation.rst
index b8908be5..8387111f 100644
--- a/docs/source/examples/semantic_segmentation.rst
+++ b/docs/source/examples/semantic_segmentation.rst
@@ -252,7 +252,7 @@ For training, we use the Jaccard index metric in addition to the accuracy and F1
         optimizer,
         criterion,
         batch_metrics=['accuracy'],
-        epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22)],
+        epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22, task="multiclass")],
         device=device,
     )
 
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 6654980b..017b571e 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -173,22 +173,22 @@ Examples:
 
     metric_collection = MetricCollection(
         [
-            F1Score(num_classes=10, average='macro'),
-            Precision(num_classes=10, average='macro'),
-            Recall(num_classes=10, average='macro'),
+            F1Score(num_classes=10, average="macro", task="multiclass"),
+            Precision(num_classes=10, average="macro", task="multiclass"),
+            Recall(num_classes=10, average="macro", task="multiclass"),
         ]
     )
 
     metrics = [
-        ('custom_name', my_custom_metric),
-        (('metric_1', 'metric_2'), my_custom_metric2),
-        (('a', 'b'), my_custom_metric3),
-        (('metric_3', 'metric_4'), CustomMetric()),
-        (('c', 'd'), CustomMetric2()),
+        ("custom_name", my_custom_metric),
+        (("metric_1", "metric_2"), my_custom_metric2),
+        (("a", "b"), my_custom_metric3),
+        (("metric_3", "metric_4"), CustomMetric()),
+        (("c", "d"), CustomMetric2()),
         # No need to pass the names since the class sets the attribute __name__.
         CustomMetric3(),
         # The names are the keys returned by MetricCollection.
-        (('F1Score', 'Precision', 'Recall'), metric_collection),
+        (("F1Score", "Precision", "Recall"), metric_collection),
     ]
 
 
diff --git a/examples/semantic_segmentation.ipynb b/examples/semantic_segmentation.ipynb
index eccaae75..05d55cbb 100644
--- a/examples/semantic_segmentation.ipynb
+++ b/examples/semantic_segmentation.ipynb
@@ -990,7 +990,7 @@
     "    optimizer,\n",
     "    criterion,\n",
     "    batch_metrics=['accuracy'],\n",
-    "    epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22)],\n",
+    "    epoch_metrics=['f1', torchmetrics.JaccardIndex(num_classes=22, task=\"multiclass\")],\n",
     "    device=device,\n",
     ")\n",
     "\n",