Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay scikit-learn import until first use #1061

Merged
merged 7 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 165 additions & 181 deletions qiskit_experiments/data_processing/sklearn_discriminators.py
Original file line number Diff line number Diff line change
@@ -1,181 +1,165 @@
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Discriminators that wrap SKLearn."""

from typing import Any, List, Dict

from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
from qiskit_experiments.data_processing.exceptions import DataProcessorError

try:
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
QuadraticDiscriminantAnalysis,
)

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class SkLDA(BaseDiscriminator):
    """A wrapper for the SKlearn linear discriminant analysis.

    ``fit`` and ``predict`` are forwarded to the wrapped object, while
    ``config`` and ``from_config`` capture and restore its parameters
    together with the attributes named in ``self.attributes``.
    """

    def __init__(self, lda: "LinearDiscriminantAnalysis"):
        """
        Args:
            lda: The sklearn linear discriminant analysis. This may be a trained or an
                untrained discriminator.

        Raises:
            DataProcessorError: if SKlearn could not be imported.
        """
        if not HAS_SKLEARN:
            raise DataProcessorError(
                f"SKlearn is needed to initialize an {self.__class__.__name__}."
            )

        self._lda = lda
        # Names of the attributes that config() reads from the wrapped object
        # and that from_config() writes back onto a fresh one.
        self.attributes = [
            "coef_",
            "intercept_",
            "covariance_",
            "explained_variance_ratio_",
            "means_",
            "priors_",
            "scalings_",
            "xbar_",
            "classes_",
            "n_features_in_",
            "feature_names_in_",
        ]

    @property
    def discriminator(self) -> Any:
        """Return the SKLearn object."""
        return self._lda

    def is_trained(self) -> bool:
        """Return True if the discriminator has been trained on data."""
        # A missing (or None) ``classes_`` attribute is treated as "not trained".
        return getattr(self._lda, "classes_", None) is not None

    def predict(self, data: List):
        """Wrap the predict method of the LDA."""
        return self._lda.predict(data)

    def fit(self, data: List, labels: List):
        """Fit the LDA.

        Args:
            data: The independent data.
            labels: The labels corresponding to data.
        """
        self._lda.fit(data, labels)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of the LDA.

        Returns:
            A dictionary with the key ``"params"`` holding the result of the
            wrapped object's ``get_params()`` and the key ``"attributes"``
            mapping every name in ``self.attributes`` to its current value,
            or to ``None`` when the wrapped object does not have it.
        """
        attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes}
        return {"params": self._lda.get_params(), "attributes": attr_conf}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
        """Deserialize from an object.

        Args:
            config: A dictionary with the keys ``"params"`` and ``"attributes"``,
                in the layout returned by :meth:`config`.

        Returns:
            A new instance of ``cls`` wrapping the restored LDA.

        Raises:
            DataProcessorError: if SKlearn could not be imported.
        """
        if not HAS_SKLEARN:
            raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")

        lda = LinearDiscriminantAnalysis()
        lda.set_params(**config["params"])

        # Attributes recorded as None were absent when the configuration was
        # taken, so they are not set on the new object.
        for name, value in config["attributes"].items():
            if value is not None:
                setattr(lda, name, value)

        # Construct through ``cls`` so that subclasses deserialize to their own type.
        return cls(lda)


class SkQDA(BaseDiscriminator):
    """A wrapper for the SKlearn quadratic discriminant analysis.

    ``fit`` and ``predict`` are forwarded to the wrapped object, while
    ``config`` and ``from_config`` capture and restore its parameters
    together with the attributes named in ``self.attributes``.
    """

    def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
        """
        Args:
            qda: The sklearn quadratic discriminant analysis. This may be a trained or an
                untrained discriminator.

        Raises:
            DataProcessorError: if SKlearn could not be imported.
        """
        if not HAS_SKLEARN:
            raise DataProcessorError(
                f"SKlearn is needed to initialize an {self.__class__.__name__}."
            )

        self._qda = qda
        # Names of the attributes that config() reads from the wrapped object
        # and that from_config() writes back onto a fresh one.
        self.attributes = [
            "coef_",
            "intercept_",
            "covariance_",
            "explained_variance_ratio_",
            "means_",
            "priors_",
            "scalings_",
            "xbar_",
            "classes_",
            "n_features_in_",
            "feature_names_in_",
            "rotations_",
        ]

    @property
    def discriminator(self) -> Any:
        """Return the SKLearn object."""
        return self._qda

    def is_trained(self) -> bool:
        """Return True if the discriminator has been trained on data."""
        # A missing (or None) ``classes_`` attribute is treated as "not trained".
        return getattr(self._qda, "classes_", None) is not None

    def predict(self, data: List):
        """Wrap the predict method of the QDA."""
        return self._qda.predict(data)

    def fit(self, data: List, labels: List):
        """Fit the QDA.

        Args:
            data: The independent data.
            labels: The labels corresponding to data.
        """
        self._qda.fit(data, labels)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of the QDA.

        Returns:
            A dictionary with the key ``"params"`` holding the result of the
            wrapped object's ``get_params()`` and the key ``"attributes"``
            mapping every name in ``self.attributes`` to its current value,
            or to ``None`` when the wrapped object does not have it.
        """
        attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes}
        return {"params": self._qda.get_params(), "attributes": attr_conf}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
        """Deserialize from an object.

        Args:
            config: A dictionary with the keys ``"params"`` and ``"attributes"``,
                in the layout returned by :meth:`config`.

        Returns:
            A new instance of ``cls`` wrapping the restored QDA.

        Raises:
            DataProcessorError: if SKlearn could not be imported.
        """
        if not HAS_SKLEARN:
            raise DataProcessorError(f"SKlearn is needed to initialize an {cls.__name__}.")

        qda = QuadraticDiscriminantAnalysis()
        qda.set_params(**config["params"])

        # Attributes recorded as None were absent when the configuration was
        # taken, so they are not set on the new object.
        for name, value in config["attributes"].items():
            if value is not None:
                setattr(qda, name, value)

        # Construct through ``cls`` so that subclasses deserialize to their own type.
        return cls(qda)
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Discriminators that wrap SKLearn."""

from typing import Any, List, Dict, TYPE_CHECKING

from qiskit_experiments.data_processing.discriminator import BaseDiscriminator
from qiskit_experiments.warnings import HAS_SKLEARN

if TYPE_CHECKING:
from sklearn.discriminant_analysis import (
LinearDiscriminantAnalysis,
QuadraticDiscriminantAnalysis,
)


class SkLDA(BaseDiscriminator):
"""A wrapper for the SKlearn linear discriminant analysis."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These classes should probably have a note in the docstring that they require sklearn to run.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added notes to the classes that say they require scikit-learn. If you want to get very pedantic, the classes here only require scikit-learn for the from_config methods. The __init__ methods could work with a custom class if it satisfied the Analysis interface of scikit-learn.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true. In the future it would be helpful to have a how-to that explains how to work with discriminators and implement a custom class.


def __init__(self, lda: "LinearDiscriminantAnalysis"):
    """
    Args:
        lda: The sklearn linear discriminant analysis. This may be a trained or an
            untrained discriminator.
    """
    # No availability check for scikit-learn is performed here, so the
    # constructor itself raises nothing; the object is stored as given.
    self._lda = lda
    # Names of the attributes that config() reads from the wrapped object
    # and that from_config() writes back onto a fresh one.
    self.attributes = [
        "coef_",
        "intercept_",
        "covariance_",
        "explained_variance_ratio_",
        "means_",
        "priors_",
        "scalings_",
        "xbar_",
        "classes_",
        "n_features_in_",
        "feature_names_in_",
    ]

@property
def discriminator(self) -> Any:
    """Return the SKLearn object wrapped by this discriminator."""
    wrapped = self._lda
    return wrapped

def is_trained(self) -> bool:
    """Return True if the discriminator has been trained on data.

    Returns:
        ``True`` when the wrapped object has a ``classes_`` attribute that is
        not ``None``; ``False`` when it is missing or ``None``.
    """
    return getattr(self._lda, "classes_", None) is not None

def predict(self, data: List):
    """Forward ``data`` to the wrapped LDA's ``predict`` and return its result."""
    model = self._lda
    prediction = model.predict(data)
    return prediction

def fit(self, data: List, labels: List):
    """Train the wrapped LDA.

    Args:
        data: The independent data.
        labels: The labels corresponding to data.
    """
    model = self._lda
    # The return value of the wrapped ``fit`` is intentionally discarded.
    model.fit(data, labels)

def config(self) -> Dict[str, Any]:
    """Return the configuration of the LDA.

    The result holds the wrapped object's ``get_params()`` under ``"params"``
    and, under ``"attributes"``, one entry per name in ``self.attributes``
    (``None`` for names the wrapped object does not have).
    """
    model = self._lda
    stored = {}
    for attr_name in self.attributes:
        stored[attr_name] = getattr(model, attr_name, None)
    return {"params": model.get_params(), "attributes": stored}

@classmethod
@HAS_SKLEARN.require_in_call
def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
    """Deserialize from an object.

    scikit-learn is imported inside this method, so it must be installed for
    the call to succeed; the ``HAS_SKLEARN.require_in_call`` decorator guards
    the call accordingly.

    Args:
        config: A dictionary with the keys ``"params"`` and ``"attributes"``,
            in the layout returned by :meth:`config`.

    Returns:
        A new instance of ``cls`` wrapping the restored LDA.
    """
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

    lda = LinearDiscriminantAnalysis()
    lda.set_params(**config["params"])

    # Attributes recorded as None were absent when the configuration was
    # taken, so they are not set on the new object.
    for name, value in config["attributes"].items():
        if value is not None:
            setattr(lda, name, value)

    # Construct through ``cls`` so that subclasses deserialize to their own type.
    return cls(lda)


class SkQDA(BaseDiscriminator):
    """A wrapper for the SKlearn quadratic discriminant analysis.

    .. note::
        :meth:`from_config` imports scikit-learn and therefore needs it to be
        installed. The constructor performs no such check and stores whatever
        object it is given.
    """

    def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
        """
        Args:
            qda: The sklearn quadratic discriminant analysis. This may be a trained or an
                untrained discriminator.
        """
        self._qda = qda
        # Names of the attributes that config() reads from the wrapped object
        # and that from_config() writes back onto a fresh one.
        self.attributes = [
            "coef_",
            "intercept_",
            "covariance_",
            "explained_variance_ratio_",
            "means_",
            "priors_",
            "scalings_",
            "xbar_",
            "classes_",
            "n_features_in_",
            "feature_names_in_",
            "rotations_",
        ]

    @property
    def discriminator(self) -> Any:
        """Return the SKLearn object."""
        return self._qda

    def is_trained(self) -> bool:
        """Return True if the discriminator has been trained on data."""
        # A missing (or None) ``classes_`` attribute is treated as "not trained".
        return getattr(self._qda, "classes_", None) is not None

    def predict(self, data: List):
        """Wrap the predict method of the QDA."""
        return self._qda.predict(data)

    def fit(self, data: List, labels: List):
        """Fit the QDA.

        Args:
            data: The independent data.
            labels: The labels corresponding to data.
        """
        self._qda.fit(data, labels)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of the QDA.

        Returns:
            A dictionary with the key ``"params"`` holding the result of the
            wrapped object's ``get_params()`` and the key ``"attributes"``
            mapping every name in ``self.attributes`` to its current value,
            or to ``None`` when the wrapped object does not have it.
        """
        attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes}
        return {"params": self._qda.get_params(), "attributes": attr_conf}

    @classmethod
    @HAS_SKLEARN.require_in_call
    def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
        """Deserialize from an object.

        Args:
            config: A dictionary with the keys ``"params"`` and ``"attributes"``,
                in the layout returned by :meth:`config`.

        Returns:
            A new instance of ``cls`` wrapping the restored QDA.
        """
        from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

        qda = QuadraticDiscriminantAnalysis()
        qda.set_params(**config["params"])

        # Attributes recorded as None were absent when the configuration was
        # taken, so they are not set on the new object.
        for name, value in config["attributes"].items():
            if value is not None:
                setattr(qda, name, value)

        # Construct through ``cls`` so that subclasses deserialize to their own type.
        return cls(qda)
Loading