Skip to content

Commit

Permalink
Delay scikit-learn import until first use
Browse files Browse the repository at this point in the history
  • Loading branch information
wshanks committed Feb 28, 2023
1 parent 1d3894e commit 265ae4f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 36 deletions.
31 changes: 7 additions & 24 deletions qiskit_experiments/data_processing/sklearn_discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,18 @@

"""Discriminators that wrap SKLearn."""

from typing import Any, List, Dict
from typing import Any, List, Dict, TYPE_CHECKING

from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.warnings import require_sklearn

try:
if TYPE_CHECKING:
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
QuadraticDiscriminantAnalysis,
)

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class SkLDA(BaseDiscriminator):
"""A wrapper for the SKlearn linear discriminant analysis."""
Expand All @@ -40,11 +37,6 @@ def __init__(self, lda: "LinearDiscriminantAnalysis"):
Raises:
DataProcessorError: if SKlearn could not be imported.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

self._lda = lda
self.attributes = [
"coef_",
Expand Down Expand Up @@ -88,12 +80,10 @@ def config(self) -> Dict[str, Any]:
return {"params": self._lda.get_params(), "attributes": attr_conf}

@classmethod
@require_sklearn
def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
"""Deserialize from an object."""

if not HAS_SKLEARN:
raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda = LinearDiscriminantAnalysis()
lda.set_params(**config["params"])

Expand All @@ -116,11 +106,6 @@ def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
Raises:
DataProcessorError: if SKlearn could not be imported.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

self._qda = qda
self.attributes = [
"coef_",
Expand Down Expand Up @@ -165,12 +150,10 @@ def config(self) -> Dict[str, Any]:
return {"params": self._qda.get_params(), "attributes": attr_conf}

@classmethod
@require_sklearn
def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
"""Deserialize from an object."""

if not HAS_SKLEARN:
raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")

from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
qda = QuadraticDiscriminantAnalysis()
qda.set_params(**config["params"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

"""Multi state discrimination analysis."""

from typing import List, Tuple
from typing import List, Tuple, TYPE_CHECKING

import matplotlib
import numpy as np
Expand All @@ -22,14 +22,11 @@
from qiskit_experiments.data_processing import SkQDA
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle
from qiskit_experiments.warnings import require_sklearn

try:
if TYPE_CHECKING:
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class MultiStateDiscriminationAnalysis(BaseAnalysis):
r"""This class fits a multi-state discriminator to the data.
Expand All @@ -51,14 +48,10 @@ def __init__(self):
Raises:
DataProcessorError: if sklearn is not installed.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

super().__init__()

@classmethod
@require_sklearn
def _default_options(cls) -> Options:
"""Return default analysis options.
Expand All @@ -76,6 +69,7 @@ def _default_options(cls) -> Options:
)
options.plot = True
options.ax = None
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
options.discriminator = SkQDA(QuadraticDiscriminantAnalysis())
return options

Expand Down
24 changes: 24 additions & 0 deletions qiskit_experiments/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import warnings
from typing import Callable, Optional, Type, Dict

from qiskit.exceptions import QiskitError


def deprecated_function(
last_version: Optional[str] = None,
Expand Down Expand Up @@ -240,3 +242,25 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def require_sklearn(func: Callable) -> Callable:
"""Decorator to check that scikit-learn is installed before running a function"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
import sklearn # pylint: disable=unused-import
except ImportError as exc:
raise QiskitError(
f"{func.__qualname__} requires the scikit-learn package, but "
"sklearn could not be imported."
) from exc
except BaseException:
raise QiskitError(
f"{func.__qualname__} requires the scikit-learn package which "
"appears to be installed but sklearn could not be imported."
) from exc

return func(*args, **kwargs)

return wrapper
1 change: 0 additions & 1 deletion test/data_processing/test_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class TestDiscriminator(QiskitExperimentsTestCase):
@requires_sklearn
def test_lda_serialization(self):
"""Test the serialization of a lda."""

sk_lda = LinearDiscriminantAnalysis()
sk_lda.fit([[-1, 0], [1, 0], [-1.1, 0], [0.9, 0.1]], [0, 1, 0, 1])

Expand Down

0 comments on commit 265ae4f

Please sign in to comment.