diff --git a/qiskit_experiments/data_processing/sklearn_discriminators.py b/qiskit_experiments/data_processing/sklearn_discriminators.py index ad43e13c7c..49d3072004 100644 --- a/qiskit_experiments/data_processing/sklearn_discriminators.py +++ b/qiskit_experiments/data_processing/sklearn_discriminators.py @@ -12,24 +12,24 @@ """Discriminators that wrap SKLearn.""" -from typing import Any, List, Dict +from typing import Any, List, Dict, TYPE_CHECKING from qiskit_experiments.data_processing.discriminator import BaseDiscriminator -from qiskit_experiments.data_processing.exceptions import DataProcessorError +from qiskit_experiments.warnings import HAS_SKLEARN -try: +if TYPE_CHECKING: from sklearn.discriminant_analysis import ( LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis, ) - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False - class SkLDA(BaseDiscriminator): - """A wrapper for the SKlearn linear discriminant analysis.""" + """A wrapper for the scikit-learn linear discriminant analysis. + + .. note:: + This class requires that scikit-learn is installed. + """ def __init__(self, lda: "LinearDiscriminantAnalysis"): """ @@ -40,11 +40,6 @@ def __init__(self, lda: "LinearDiscriminantAnalysis"): Raises: DataProcessorError: if SKlearn could not be imported. """ - if not HAS_SKLEARN: - raise DataProcessorError( - f"SKlearn is needed to initialize an {self.__class__.__name__}." - ) - self._lda = lda self.attributes = [ "coef_", @@ -88,11 +83,10 @@ def config(self) -> Dict[str, Any]: return {"params": self._lda.get_params(), "attributes": attr_conf} @classmethod + @HAS_SKLEARN.require_in_call def from_config(cls, config: Dict[str, Any]) -> "SkLDA": """Deserialize from an object.""" - - if not HAS_SKLEARN: - raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.") + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis lda = LinearDiscriminantAnalysis() lda.set_params(**config["params"]) @@ -105,7 +99,11 @@ def from_config(cls, config: Dict[str, Any]) -> "SkLDA": class SkQDA(BaseDiscriminator): - """A wrapper for the SKlearn quadratic discriminant analysis.""" + """A wrapper for the SKlearn quadratic discriminant analysis. + + .. note:: + This class requires that scikit-learn is installed. + """ def __init__(self, qda: "QuadraticDiscriminantAnalysis"): """ @@ -116,11 +114,6 @@ def __init__(self, qda: "QuadraticDiscriminantAnalysis"): Raises: DataProcessorError: if SKlearn could not be imported. """ - if not HAS_SKLEARN: - raise DataProcessorError( - f"SKlearn is needed to initialize an {self.__class__.__name__}." - ) - self._qda = qda self.attributes = [ "coef_", @@ -165,11 +158,10 @@ def config(self) -> Dict[str, Any]: return {"params": self._qda.get_params(), "attributes": attr_conf} @classmethod + @HAS_SKLEARN.require_in_call def from_config(cls, config: Dict[str, Any]) -> "SkQDA": """Deserialize from an object.""" - - if not HAS_SKLEARN: - raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.") + from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis qda = QuadraticDiscriminantAnalysis() qda.set_params(**config["params"]) diff --git a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py index be39380b22..94b9bd990a 100644 --- a/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/multi_state_discrimination_analysis.py @@ -12,7 +12,7 @@ """Multi state discrimination analysis.""" -from typing import List, Tuple +from typing import List, Tuple, TYPE_CHECKING import matplotlib import numpy as np @@ -20,16 +20,12 @@ from qiskit.providers.options import Options from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, ExperimentData from qiskit_experiments.data_processing import SkQDA -from qiskit_experiments.data_processing.exceptions import DataProcessorError from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle +from qiskit_experiments.warnings import HAS_SKLEARN -try: +if TYPE_CHECKING: from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False - class MultiStateDiscriminationAnalysis(BaseAnalysis): r"""This class fits a multi-state discriminator to the data. @@ -43,22 +39,13 @@ class MultiStateDiscriminationAnalysis(BaseAnalysis): Here, :math:`d` is the number of levels that were discriminated while :math:`P(i|j)` is the probability of measuring outcome :math:`i` given that state :math:`j` was prepared. - """ - def __init__(self): - """Setup the analysis. - - Raises: - DataProcessorError: if sklearn is not installed. - """ - if not HAS_SKLEARN: - raise DataProcessorError( - f"SKlearn is needed to initialize an {self.__class__.__name__}." - ) - - super().__init__() + .. note:: + This class requires that scikit-learn is installed. + """ @classmethod + @HAS_SKLEARN.require_in_call def _default_options(cls) -> Options: """Return default analysis options. @@ -76,6 +63,8 @@ def _default_options(cls) -> Options: ) options.plot = True options.ax = None + from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis + options.discriminator = SkQDA(QuadraticDiscriminantAnalysis()) return options diff --git a/qiskit_experiments/warnings.py b/qiskit_experiments/warnings.py index 4f9742e5b5..3758ac2764 100644 --- a/qiskit_experiments/warnings.py +++ b/qiskit_experiments/warnings.py @@ -16,6 +16,8 @@ import warnings from typing import Callable, Optional, Type, Dict +from qiskit.utils.lazy_tester import LazyImportTester + def deprecated_function( last_version: Optional[str] = None, @@ -240,3 +242,15 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +HAS_SKLEARN = LazyImportTester( + { + "sklearn.discriminant_analysis": ( + "LinearDiscriminantAnalysis", + "QuadraticDiscriminantAnalysis", + ) + }, + name="scikit-learn", + install="pip install scikit-learn", +) diff --git a/releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml b/releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml new file mode 100644 index 0000000000..56029aa414 --- /dev/null +++ b/releasenotes/notes/sklearn-imports-c82155c0a2c81811.yaml @@ -0,0 +1,9 @@ +--- +fixes: + - | + The importing of ``scikit-learn`` was moved from module-level imports + inside of ``try`` blocks to dynamic imports at first usage of the + ``scikit-learn`` specific feature. This change should avoid errors in the + installation of ``scikit-learn`` from preventing a user using features of + ``qiskit-experiments`` that do not require ``scikit-learn``. See `#1050 + `_. diff --git a/test/data_processing/test_discriminator.py b/test/data_processing/test_discriminator.py index dc0fea629a..7ce095ff7c 100644 --- a/test/data_processing/test_discriminator.py +++ b/test/data_processing/test_discriminator.py @@ -1,123 +1,122 @@ -# This code is part of Qiskit. -# -# (C) Copyright IBM 2022. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. - -"""Tests for the serializable discriminator objects.""" - -from test.base import QiskitExperimentsTestCase -from functools import wraps -from unittest import SkipTest -import numpy as np - -from qiskit_experiments.data_processing import SkLDA, SkQDA - -try: - from sklearn.discriminant_analysis import ( - LinearDiscriminantAnalysis, - QuadraticDiscriminantAnalysis, - ) - - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False - - -def requires_sklearn(func): - """Decorator to check for SKLearn.""" - - @wraps(func) - def wrapper(*args, **kwargs): - if not HAS_SKLEARN: - raise SkipTest("SKLearn is required for test.") - - func(*args, **kwargs) - - return wrapper - - -class TestDiscriminator(QiskitExperimentsTestCase): - """Tests for the discriminator.""" - - @requires_sklearn - def test_lda_serialization(self): - """Test the serialization of a lda.""" - - sk_lda = LinearDiscriminantAnalysis() - sk_lda.fit([[-1, 0], [1, 0], [-1.1, 0], [0.9, 0.1]], [0, 1, 0, 1]) - - self.assertTrue(sk_lda.predict([[1.1, 0]])[0], 1) - - lda = SkLDA(sk_lda) - - self.assertTrue(lda.is_trained()) - self.assertTrue(lda.predict([[1.1, 0]])[0], 1) - - def check_lda(lda1, lda2): - test_data = [[1.1, 0], [0.1, 0], [-2, 0]] - - lda1_y = lda1.predict(test_data) - lda2_y = lda2.predict(test_data) - - if len(lda1_y) != len(lda2_y): - return False - - for idx, y_val1 in enumerate(lda1_y): - if lda2_y[idx] != y_val1: - return False - - for attribute in lda1.attributes: - if not np.allclose( - getattr(lda1.discriminator, attribute, np.array([])), - getattr(lda2.discriminator, attribute, np.array([])), - ): - return False - - return True - - self.assertRoundTripSerializable(lda, check_lda) - - @requires_sklearn - def test_qda_serialization(self): - """Test the serialization of a qda.""" - - sk_qda = QuadraticDiscriminantAnalysis() - sk_qda.fit([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], [0, 0, 0, 1, 1, 1]) - - self.assertTrue(sk_qda.predict([[1.1, 3]])[0], 1) - - qda = SkQDA(sk_qda) - - self.assertTrue(qda.is_trained()) - self.assertTrue(qda.predict([[1.1, 3]])[0], 1) - - def check_qda(qda1, qda2): - test_data = [[1.1, 0], [0.1, 0], [-2, 0]] - - qda1_y = qda1.predict(test_data) - qda2_y = qda2.predict(test_data) - - if len(qda1_y) != len(qda2_y): - return False - - for idx, y_val1 in enumerate(qda1_y): - if qda2_y[idx] != y_val1: - return False - - for attribute in qda1.attributes: - if not np.allclose( - getattr(qda1.discriminator, attribute, np.array([])), - getattr(qda2.discriminator, attribute, np.array([])), - ): - return False - - return True - - self.assertRoundTripSerializable(qda, check_qda) +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Tests for the serializable discriminator objects.""" + +from test.base import QiskitExperimentsTestCase +from functools import wraps +from unittest import SkipTest +import numpy as np + +from qiskit.exceptions import MissingOptionalLibraryError + +from qiskit_experiments.data_processing import SkLDA, SkQDA +from qiskit_experiments.warnings import HAS_SKLEARN + + +def requires_sklearn(func): + """Decorator to check for SKLearn.""" + + @wraps(func) + def wrapper(*args, **kwargs): + try: + HAS_SKLEARN.require_now("SKLearn discriminator testing") + except MissingOptionalLibraryError as exc: + raise SkipTest("SKLearn is required for test.") from exc + + func(*args, **kwargs) + + return wrapper + + +class TestDiscriminator(QiskitExperimentsTestCase): + """Tests for the discriminator.""" + + @requires_sklearn + def test_lda_serialization(self): + """Test the serialization of a lda.""" + + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis + + sk_lda = LinearDiscriminantAnalysis() + sk_lda.fit([[-1, 0], [1, 0], [-1.1, 0], [0.9, 0.1]], [0, 1, 0, 1]) + + self.assertTrue(sk_lda.predict([[1.1, 0]])[0], 1) + + lda = SkLDA(sk_lda) + + self.assertTrue(lda.is_trained()) + self.assertTrue(lda.predict([[1.1, 0]])[0], 1) + + def check_lda(lda1, lda2): + test_data = [[1.1, 0], [0.1, 0], [-2, 0]] + + lda1_y = lda1.predict(test_data) + lda2_y = lda2.predict(test_data) + + if len(lda1_y) != len(lda2_y): + return False + + for idx, y_val1 in enumerate(lda1_y): + if lda2_y[idx] != y_val1: + return False + + for attribute in lda1.attributes: + if not np.allclose( + getattr(lda1.discriminator, attribute, np.array([])), + getattr(lda2.discriminator, attribute, np.array([])), + ): + return False + + return True + + self.assertRoundTripSerializable(lda, check_lda) + + @requires_sklearn + def test_qda_serialization(self): + """Test the serialization of a qda.""" + + from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis + + sk_qda = QuadraticDiscriminantAnalysis() + sk_qda.fit([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], [0, 0, 0, 1, 1, 1]) + + self.assertTrue(sk_qda.predict([[1.1, 3]])[0], 1) + + qda = SkQDA(sk_qda) + + self.assertTrue(qda.is_trained()) + self.assertTrue(qda.predict([[1.1, 3]])[0], 1) + + def check_qda(qda1, qda2): + test_data = [[1.1, 0], [0.1, 0], [-2, 0]] + + qda1_y = qda1.predict(test_data) + qda2_y = qda2.predict(test_data) + + if len(qda1_y) != len(qda2_y): + return False + + for idx, y_val1 in enumerate(qda1_y): + if qda2_y[idx] != y_val1: + return False + + for attribute in qda1.attributes: + if not np.allclose( + getattr(qda1.discriminator, attribute, np.array([])), + getattr(qda2.discriminator, attribute, np.array([])), + ): + return False + + return True + + self.assertRoundTripSerializable(qda, check_qda) diff --git a/test/framework/test_warnings.py b/test/framework/test_warnings.py index 45ab7c90de..365176ffae 100644 --- a/test/framework/test_warnings.py +++ b/test/framework/test_warnings.py @@ -13,6 +13,9 @@ # pylint: disable=unused-argument, unused-variable """Test warning helper.""" +import subprocess +import sys +import textwrap from test.base import QiskitExperimentsTestCase from qiskit_experiments.framework import BaseExperiment @@ -86,3 +89,45 @@ def __init__(self, physical_qubits): with self.assertWarns(DeprecationWarning): instance = OldExperiment(qubit=0) self.assertEqual(instance._physical_qubits, (0,)) + + def test_warn_sklearn(self): + """Test that a suggestion to import scikit-learn is given when appropriate""" + script = """ + import builtins + disallowed_imports = {"sklearn"} + old_import = builtins.__import__ + def guarded_import(name, *args, **kwargs): + if name in disallowed_imports: + raise import_error(f"Import of {name} not allowed!") + return old_import(name, *args, **kwargs) + builtins.__import__ = guarded_import + # Raise Exception on imports so that ImportError can't be caught + import_error = Exception + import qiskit_experiments + print("qiskit_experiments imported!") + # Raise ImportError so the guard can catch it + import_error = ImportError + from qiskit_experiments.data_processing.sklearn_discriminators import SkLDA + SkLDA.from_config({}) + """ + script = textwrap.dedent(script) + + proc = subprocess.run( + [sys.executable, "-c", script], check=False, text=True, capture_output=True + ) + + self.assertTrue( + proc.stdout.startswith("qiskit_experiments imported!"), + msg="Failed to import qiskit_experiments without sklearn", + ) + + self.assertNotEqual( + proc.returncode, + 0, + msg="scikit-learn usage did not error without scikit-learn available", + ) + self.assertTrue( + "qiskit.exceptions.MissingOptionalLibraryError" in proc.stderr + and "scikit-learn" in proc.stderr, + msg="scikit-learn import guard did not run on scikit-learn usage", + ) diff --git a/test/library/characterization/test_multi_state_discrimination.py b/test/library/characterization/test_multi_state_discrimination.py index 4beb0d4d8d..508299cd75 100644 --- a/test/library/characterization/test_multi_state_discrimination.py +++ b/test/library/characterization/test_multi_state_discrimination.py @@ -11,13 +11,35 @@ # that they have been altered from the originals. """Test the multi state discrimination experiments.""" +from functools import wraps from test.base import QiskitExperimentsTestCase +from unittest import SkipTest + from ddt import ddt, data from qiskit import pulse +from qiskit.exceptions import MissingOptionalLibraryError + from qiskit_experiments.library import MultiStateDiscrimination from qiskit_experiments.test.pulse_backend import SingleTransmonTestBackend +from qiskit_experiments.warnings import HAS_SKLEARN + + +def requires_sklearn(func): + """Decorator to check for SKLearn.""" + + @wraps(func) + def wrapper(*args, **kwargs): + try: + HAS_SKLEARN.require_now("SKLearn discriminator testing") + except MissingOptionalLibraryError as exc: + raise SkipTest("SKLearn is required for test.") from exc + + func(*args, **kwargs) + + return wrapper + @ddt class TestMultiStateDiscrimination(QiskitExperimentsTestCase): @@ -52,6 +74,7 @@ def setUp(self): self.schedules = {"x12": x12} @data(2, 3) + @requires_sklearn def test_circuit_generation(self, n_states): """Test the experiment circuit generation""" exp = MultiStateDiscrimination( @@ -63,6 +86,7 @@ def test_circuit_generation(self, n_states): self.assertEqual(exp.circuits()[-1].metadata["label"], n_states - 1) @data(2, 3) + @requires_sklearn def test_discrimination_analysis(self, n_states): """Test the discrimination analysis""" exp = MultiStateDiscrimination(