From a43de1af3a15a995004ec92e6160075166f2b27f Mon Sep 17 00:00:00 2001 From: zhanglirong1999 <1695074375@qq.com> Date: Wed, 17 Aug 2022 19:00:38 +0800 Subject: [PATCH] back commit --- .../src/bigdl/orca/learn/pytorch/pytorch_metrics.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_metrics.py b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_metrics.py index 5312fa016e8..b339f316541 100644 --- a/python/orca/src/bigdl/orca/learn/pytorch/pytorch_metrics.py +++ b/python/orca/src/bigdl/orca/learn/pytorch/pytorch_metrics.py @@ -15,13 +15,6 @@ # import torch from bigdl.dllib.utils.log4Error import invalidInputError - -try: - import torchmetrics -except ImportError: - invalidInputError(False, - "please install torchmetrics: pip install torchmetrics") - from abc import ABC, abstractmethod @@ -466,6 +459,7 @@ class AUROC(PytorchMetric): ``` """ def __init__(self): + import torchmetrics self.internal_auc = torchmetrics.AUROC() def __call__(self, preds, targets): @@ -497,6 +491,7 @@ class ROC(PytorchMetric): """ def __init__(self): + import torchmetrics self.internal_roc = torchmetrics.ROC() def __call__(self, preds, targets): @@ -521,6 +516,7 @@ class F1Score(PytorchMetric): """ def __init__(self): + import torchmetrics self.internal_f1 = torchmetrics.F1Score() def __call__(self, preds, targets): @@ -546,6 +542,7 @@ class Precision(PytorchMetric): """ def __init__(self): + import torchmetrics self.internal_precision = torchmetrics.Precision() def __call__(self, preds, targets): @@ -571,6 +568,7 @@ class Recall(PytorchMetric): """ def __init__(self): + import torchmetrics self.internal_recall = torchmetrics.Recall() def __call__(self, preds, targets): @@ -602,6 +600,7 @@ class PrecisionRecallCurve(PytorchMetric): """ def __init__(self): + import torchmetrics self.internal_curve = torchmetrics.PrecisionRecallCurve() def __call__(self, preds, targets):