Delay scikit-learn import until first use #1061
Merged
74b4a50 Delay scikit-learn import until first use (wshanks)
575efd1 Address pylint warnings (wshanks)
7d9d957 Add test of warning about optional scikit-learn dependency (wshanks)
0a6b980 Add release note (wshanks)
c826d0f Apply suggestions from code review (wshanks)
fc3f5f9 Make sklearn import test sensitive to any error in sklearn import (wshanks)
249c427 Document classes that require scikit-learn (wshanks)
346 changes: 165 additions & 181 deletions
qiskit_experiments/data_processing/sklearn_discriminators.py
```diff
@@ -1,181 +1,165 @@
-# This code is part of Qiskit.
-#
-# (C) Copyright IBM 2022.
-#
-# This code is licensed under the Apache License, Version 2.0. You may
-# obtain a copy of this license in the LICENSE.txt file in the root directory
-# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
-#
-# Any modifications or derivative works of this code must retain this
-# copyright notice, and modified files need to carry a notice indicating
-# that they have been altered from the originals.
-
-"""Discriminators that wrap SKLearn."""
-
-from typing import Any, List, Dict
-
-from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
-from qiskit_experiments.data_processing.exceptions import DataProcessorError
-
-try:
-    from sklearn.discriminant_analysis import (
-        LinearDiscriminantAnalysis,
-        QuadraticDiscriminantAnalysis,
-    )
-
-    HAS_SKLEARN = True
-except ImportError:
-    HAS_SKLEARN = False
-
-
-class SkLDA(BaseDiscriminator):
-    """A wrapper for the SKlearn linear discriminant analysis."""
-
-    def __init__(self, lda: "LinearDiscriminantAnalysis"):
-        """
-        Args:
-            lda: The sklearn linear discriminant analysis. This may be a trained or an
-                untrained discriminator.
-
-        Raises:
-            DataProcessorError: if SKlearn could not be imported.
-        """
-        if not HAS_SKLEARN:
-            raise DataProcessorError(
-                f"SKlearn is needed to initialize an {self.__class__.__name__}."
-            )
-
-        self._lda = lda
-        self.attributes = [
-            "coef_",
-            "intercept_",
-            "covariance_",
-            "explained_variance_ratio_",
-            "means_",
-            "priors_",
-            "scalings_",
-            "xbar_",
-            "classes_",
-            "n_features_in_",
-            "feature_names_in_",
-        ]
-
-    @property
-    def discriminator(self) -> Any:
-        """Return then SKLearn object."""
-        return self._lda
-
-    def is_trained(self) -> bool:
-        """Return True if the discriminator has been trained on data."""
-        return not getattr(self._lda, "classes_", None) is None
-
-    def predict(self, data: List):
-        """Wrap the predict method of the LDA."""
-        return self._lda.predict(data)
-
-    def fit(self, data: List, labels: List):
-        """Fit the LDA.
-
-        Args:
-            data: The independent data.
-            labels: The labels corresponding to data.
-        """
-        self._lda.fit(data, labels)
-
-    def config(self) -> Dict[str, Any]:
-        """Return the configuration of the LDA."""
-        attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes}
-        return {"params": self._lda.get_params(), "attributes": attr_conf}
-
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
-        """Deserialize from an object."""
-
-        if not HAS_SKLEARN:
-            raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")
-
-        lda = LinearDiscriminantAnalysis()
-        lda.set_params(**config["params"])
-
-        for name, value in config["attributes"].items():
-            if value is not None:
-                setattr(lda, name, value)
-
-        return SkLDA(lda)
-
-
-class SkQDA(BaseDiscriminator):
-    """A wrapper for the SKlearn quadratic discriminant analysis."""
-
-    def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
-        """
-        Args:
-            qda: The sklearn quadratic discriminant analysis. This may be a trained or an
-                untrained discriminator.
-
-        Raises:
-            DataProcessorError: if SKlearn could not be imported.
-        """
-        if not HAS_SKLEARN:
-            raise DataProcessorError(
-                f"SKlearn is needed to initialize an {self.__class__.__name__}."
-            )
-
-        self._qda = qda
-        self.attributes = [
-            "coef_",
-            "intercept_",
-            "covariance_",
-            "explained_variance_ratio_",
-            "means_",
-            "priors_",
-            "scalings_",
-            "xbar_",
-            "classes_",
-            "n_features_in_",
-            "feature_names_in_",
-            "rotations_",
-        ]
-
-    @property
-    def discriminator(self) -> Any:
-        """Return then SKLearn object."""
-        return self._qda
-
-    def is_trained(self) -> bool:
-        """Return True if the discriminator has been trained on data."""
-        return not getattr(self._qda, "classes_", None) is None
-
-    def predict(self, data: List):
-        """Wrap the predict method of the QDA."""
-        return self._qda.predict(data)
-
-    def fit(self, data: List, labels: List):
-        """Fit the QDA.
-
-        Args:
-            data: The independent data.
-            labels: The labels corresponding to data.
-        """
-        self._qda.fit(data, labels)
-
-    def config(self) -> Dict[str, Any]:
-        """Return the configuration of the QDA."""
-        attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes}
-        return {"params": self._qda.get_params(), "attributes": attr_conf}
-
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
-        """Deserialize from an object."""
-
-        if not HAS_SKLEARN:
-            raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")
-
-        qda = QuadraticDiscriminantAnalysis()
-        qda.set_params(**config["params"])
-
-        for name, value in config["attributes"].items():
-            if value is not None:
-                setattr(qda, name, value)
-
-        return SkQDA(qda)
+# This code is part of Qiskit.
+#
+# (C) Copyright IBM 2022.
+#
+# This code is licensed under the Apache License, Version 2.0. You may
+# obtain a copy of this license in the LICENSE.txt file in the root directory
+# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# Any modifications or derivative works of this code must retain this
+# copyright notice, and modified files need to carry a notice indicating
+# that they have been altered from the originals.
+
+"""Discriminators that wrap SKLearn."""
+
+from typing import Any, List, Dict, TYPE_CHECKING
+
+from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
+from qiskit_experiments.warnings import HAS_SKLEARN
+
+if TYPE_CHECKING:
+    from sklearn.discriminant_analysis import (
+        LinearDiscriminantAnalysis,
+        QuadraticDiscriminantAnalysis,
+    )
+
+
+class SkLDA(BaseDiscriminator):
+    """A wrapper for the SKlearn linear discriminant analysis."""
+
+    def __init__(self, lda: "LinearDiscriminantAnalysis"):
+        """
+        Args:
+            lda: The sklearn linear discriminant analysis. This may be a trained or an
+                untrained discriminator.
+
+        Raises:
+            DataProcessorError: if SKlearn could not be imported.
+        """
+        self._lda = lda
+        self.attributes = [
+            "coef_",
+            "intercept_",
+            "covariance_",
+            "explained_variance_ratio_",
+            "means_",
+            "priors_",
+            "scalings_",
+            "xbar_",
+            "classes_",
+            "n_features_in_",
+            "feature_names_in_",
+        ]
+
+    @property
+    def discriminator(self) -> Any:
+        """Return then SKLearn object."""
+        return self._lda
+
+    def is_trained(self) -> bool:
+        """Return True if the discriminator has been trained on data."""
+        return not getattr(self._lda, "classes_", None) is None
+
+    def predict(self, data: List):
+        """Wrap the predict method of the LDA."""
+        return self._lda.predict(data)
+
+    def fit(self, data: List, labels: List):
+        """Fit the LDA.
+
+        Args:
+            data: The independent data.
+            labels: The labels corresponding to data.
+        """
+        self._lda.fit(data, labels)
+
+    def config(self) -> Dict[str, Any]:
+        """Return the configuration of the LDA."""
+        attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes}
+        return {"params": self._lda.get_params(), "attributes": attr_conf}
+
+    @classmethod
+    @HAS_SKLEARN.require_in_call
+    def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
+        """Deserialize from an object."""
+        from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+
+        lda = LinearDiscriminantAnalysis()
+        lda.set_params(**config["params"])
+
+        for name, value in config["attributes"].items():
+            if value is not None:
+                setattr(lda, name, value)
+
+        return SkLDA(lda)
+
+
+class SkQDA(BaseDiscriminator):
+    """A wrapper for the SKlearn quadratic discriminant analysis."""
+
+    def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
+        """
+        Args:
+            qda: The sklearn quadratic discriminant analysis. This may be a trained or an
+                untrained discriminator.
+
+        Raises:
+            DataProcessorError: if SKlearn could not be imported.
+        """
+        self._qda = qda
+        self.attributes = [
+            "coef_",
+            "intercept_",
+            "covariance_",
+            "explained_variance_ratio_",
+            "means_",
+            "priors_",
+            "scalings_",
+            "xbar_",
+            "classes_",
+            "n_features_in_",
+            "feature_names_in_",
+            "rotations_",
+        ]
+
+    @property
+    def discriminator(self) -> Any:
+        """Return then SKLearn object."""
+        return self._qda
+
+    def is_trained(self) -> bool:
+        """Return True if the discriminator has been trained on data."""
+        return not getattr(self._qda, "classes_", None) is None
+
+    def predict(self, data: List):
+        """Wrap the predict method of the QDA."""
+        return self._qda.predict(data)
+
+    def fit(self, data: List, labels: List):
+        """Fit the QDA.
+
+        Args:
+            data: The independent data.
+            labels: The labels corresponding to data.
+        """
+        self._qda.fit(data, labels)
+
+    def config(self) -> Dict[str, Any]:
+        """Return the configuration of the QDA."""
+        attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes}
+        return {"params": self._qda.get_params(), "attributes": attr_conf}
+
+    @classmethod
+    @HAS_SKLEARN.require_in_call
+    def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
+        """Deserialize from an object."""
+        from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
+
+        qda = QuadraticDiscriminantAnalysis()
+        qda.set_params(**config["params"])
+
+        for name, value in config["attributes"].items():
+            if value is not None:
+                setattr(qda, name, value)
+
+        return SkQDA(qda)
```
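The diff replaces a module-level try/except import with a `HAS_SKLEARN` object whose `require_in_call` decorator guards `from_config`. The diff does not show how `HAS_SKLEARN` is implemented, so the following is only a minimal standard-library sketch of the general pattern. The names `LazyDependency`, `MissingOptionalDependency`, and `Loader` are hypothetical, and `json` stands in for the optional package so the example runs anywhere.

```python
import functools
import importlib


class MissingOptionalDependency(ImportError):
    """Raised when an optional dependency is needed but cannot be imported."""


class LazyDependency:
    """Hypothetical stand-in for a lazy import checker such as HAS_SKLEARN."""

    def __init__(self, module_name):
        self._module_name = module_name
        self._available = None  # unknown until the first check

    def __bool__(self):
        # Attempt the import only on first use and cache the outcome.
        if self._available is None:
            try:
                importlib.import_module(self._module_name)
            except ImportError:
                self._available = False
            else:
                self._available = True
        return self._available

    def require_in_call(self, func):
        """Wrap func so the dependency is checked each time func is called."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not self:
                raise MissingOptionalDependency(
                    f"'{self._module_name}' is required to call '{func.__name__}'."
                )
            return func(*args, **kwargs)

        return wrapper


# 'json' plays the role of the optional package; the second name is assumed absent.
HAS_JSON = LazyDependency("json")
HAS_MISSING = LazyDependency("surely_not_an_installed_module_xyz")


class Loader:
    @classmethod
    @HAS_JSON.require_in_call
    def from_text(cls, text):
        import json  # imported only when the method actually runs

        return json.loads(text)

    @classmethod
    @HAS_MISSING.require_in_call
    def from_missing(cls, text):
        raise AssertionError("not reached when the module is absent")
```

Creating the checker imports nothing, so importing the module that defines `Loader` stays cheap. The import cost, or the error for a missing package, is deferred to the first guarded call.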
These classes should probably have a note in the docstring that they require sklearn to run.
I added notes to the classes that say they require scikit-learn. If you want to get very pedantic, the classes here only require scikit-learn for the `from_config` methods. The `__init__` methods could work with a custom class if it satisfied the Analysis interface of scikit-learn.
That's true. In the future it would be helpful to have a how-to that explains how to work with discriminators and implement a custom class.
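As a rough illustration of the point about custom classes, the sketch below defines a small duck-typed classifier that provides only the members the wrappers in the diff touch (`fit`, `predict`, `get_params`, `set_params`, and a `classes_` attribute set by training). `NearestCentroid2D` and `WrapperSketch` are hypothetical names. `WrapperSketch` only mirrors the methods shown in the diff and is not the real `SkLDA`, so whether the actual class accepts such an object is an assumption that would need checking against the library.

```python
class NearestCentroid2D:
    """Hypothetical classifier with the scikit-learn style methods the wrappers call."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def get_params(self):
        return {"scale": self.scale}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, data, labels):
        # Average the points of each label to obtain one centroid per class.
        grouped = {}
        for point, label in zip(data, labels):
            grouped.setdefault(label, []).append(point)
        self.classes_ = sorted(grouped)
        self.means_ = [
            [sum(p[k] for p in grouped[c]) / len(grouped[c]) for k in range(2)]
            for c in self.classes_
        ]
        return self

    def predict(self, data):
        def nearest(point):
            # Squared Euclidean distance from the point to each class centroid.
            dists = [
                sum((point[k] - mean[k]) ** 2 for k in range(2)) for mean in self.means_
            ]
            return self.classes_[dists.index(min(dists))]

        return [nearest(point) for point in data]


class WrapperSketch:
    """Stand-in that mirrors the wrapper methods from the diff (not the real SkLDA)."""

    def __init__(self, model):
        self._model = model
        self.attributes = ["means_", "classes_"]

    def is_trained(self):
        return getattr(self._model, "classes_", None) is not None

    def fit(self, data, labels):
        self._model.fit(data, labels)

    def predict(self, data):
        return self._model.predict(data)

    def config(self):
        attr_conf = {attr: getattr(self._model, attr, None) for attr in self.attributes}
        return {"params": self._model.get_params(), "attributes": attr_conf}


wrapper = WrapperSketch(NearestCentroid2D())
wrapper.fit([[0.0, 0.1], [0.2, -0.1], [1.0, 1.1], [0.9, 0.8]], ["0", "0", "1", "1"])
print(wrapper.predict([[0.1, 0.0], [1.0, 1.0]]))
```

A `from_config` counterpart is left out on purpose: in the diff that method constructs the scikit-learn estimator itself, which is why it is the only part guarded by `HAS_SKLEARN.require_in_call`.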