Skip to content

Commit

Permalink
Delay scikit-learn import until first use (#1061)
Browse files Browse the repository at this point in the history
### Summary

This commit pushes the import of scikit-learn lower into the code, so
that it does not occur until first usage.

### Details and comments

Previously, scikit-learn was in a `try` block which allowed for code
that did not use scikit-learn to work fine when it was not installed.
However, sometimes scikit-learn can be installed but have errors (like
in #1050) which were not caught by the `try` block. Further delaying the
import can help in this case. Additionally, scikit-learn is a little bit
of a slow import, so not importing it when it is not needed gives a
little bit of efficiency (maybe; mostly it imports scipy modules but
those might get imported any way by other analysis code).

---------

Co-authored-by: Helena Zhang <[email protected]>
  • Loading branch information
wshanks and coruscating authored Mar 3, 2023
1 parent fd9f65e commit 0a9da6a
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 168 deletions.
42 changes: 17 additions & 25 deletions qiskit_experiments/data_processing/sklearn_discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@

"""Discriminators that wrap SKLearn."""

from typing import Any, List, Dict
from typing import Any, List, Dict, TYPE_CHECKING

from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.warnings import HAS_SKLEARN

try:
if TYPE_CHECKING:
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
QuadraticDiscriminantAnalysis,
)

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class SkLDA(BaseDiscriminator):
"""A wrapper for the SKlearn linear discriminant analysis."""
"""A wrapper for the scikit-learn linear discriminant analysis.
.. note::
This class requires that scikit-learn is installed.
"""

def __init__(self, lda: "LinearDiscriminantAnalysis"):
"""
Expand All @@ -40,11 +40,6 @@ def __init__(self, lda: "LinearDiscriminantAnalysis"):
Raises:
DataProcessorError: if SKlearn could not be imported.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

self._lda = lda
self.attributes = [
"coef_",
Expand Down Expand Up @@ -88,11 +83,10 @@ def config(self) -> Dict[str, Any]:
return {"params": self._lda.get_params(), "attributes": attr_conf}

@classmethod
@HAS_SKLEARN.require_in_call
def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
"""Deserialize from an object."""

if not HAS_SKLEARN:
raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

lda = LinearDiscriminantAnalysis()
lda.set_params(**config["params"])
Expand All @@ -105,7 +99,11 @@ def from_config(cls, config: Dict[str, Any]) -> "SkLDA":


class SkQDA(BaseDiscriminator):
"""A wrapper for the SKlearn quadratic discriminant analysis."""
"""A wrapper for the SKlearn quadratic discriminant analysis.
.. note::
This class requires that scikit-learn is installed.
"""

def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
"""
Expand All @@ -116,11 +114,6 @@ def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
Raises:
DataProcessorError: if SKlearn could not be imported.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

self._qda = qda
self.attributes = [
"coef_",
Expand Down Expand Up @@ -165,11 +158,10 @@ def config(self) -> Dict[str, Any]:
return {"params": self._qda.get_params(), "attributes": attr_conf}

@classmethod
@HAS_SKLEARN.require_in_call
def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
"""Deserialize from an object."""

if not HAS_SKLEARN:
raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

qda = QuadraticDiscriminantAnalysis()
qda.set_params(**config["params"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@

"""Multi state discrimination analysis."""

from typing import List, Tuple
from typing import List, Tuple, TYPE_CHECKING

import matplotlib
import numpy as np

from qiskit.providers.options import Options
from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, ExperimentData
from qiskit_experiments.data_processing import SkQDA
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle
from qiskit_experiments.warnings import HAS_SKLEARN

try:
if TYPE_CHECKING:
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class MultiStateDiscriminationAnalysis(BaseAnalysis):
r"""This class fits a multi-state discriminator to the data.
Expand All @@ -43,22 +39,13 @@ class MultiStateDiscriminationAnalysis(BaseAnalysis):
Here, :math:`d` is the number of levels that were discriminated while :math:`P(i|j)` is the
probability of measuring outcome :math:`i` given that state :math:`j` was prepared.
"""
def __init__(self):
"""Setup the analysis.
Raises:
DataProcessorError: if sklearn is not installed.
"""
if not HAS_SKLEARN:
raise DataProcessorError(
f"SKlearn is needed to initialize an {self.__class__.__name__}."
)

super().__init__()
.. note::
This class requires that scikit-learn is installed.
"""

@classmethod
@HAS_SKLEARN.require_in_call
def _default_options(cls) -> Options:
"""Return default analysis options.
Expand All @@ -76,6 +63,8 @@ def _default_options(cls) -> Options:
)
options.plot = True
options.ax = None
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

options.discriminator = SkQDA(QuadraticDiscriminantAnalysis())
return options

Expand Down
14 changes: 14 additions & 0 deletions qiskit_experiments/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import warnings
from typing import Callable, Optional, Type, Dict

from qiskit.utils.lazy_tester import LazyImportTester


def deprecated_function(
last_version: Optional[str] = None,
Expand Down Expand Up @@ -240,3 +242,15 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


HAS_SKLEARN = LazyImportTester(
{
"sklearn.discriminant_analysis": (
"LinearDiscriminantAnalysis",
"QuadraticDiscriminantAnalysis",
)
},
name="scikit-learn",
install="pip install scikit-learn",
)
9 changes: 9 additions & 0 deletions releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
fixes:
- |
The importing of ``scikit-learn`` was moved from module-level imports
inside of ``try`` blocks to dynamic imports at first usage of the
``scikit-learn`` specific feature. This change should avoid errors in the
installation of ``scikit-learn`` from preventing a user using features of
``qiskit-experiments`` that do not require ``scikit-learn``. See `#1050
<https://github.com/Qiskit/qiskit-experiments/issues/1050>`_.
Loading

0 comments on commit 0a9da6a

Please sign in to comment.