diff --git a/src/safeds/ml/classical/regression/_elastic_net_regression.py b/src/safeds/ml/classical/regression/_elastic_net_regression.py index 05de26dfb..878544be5 100644 --- a/src/safeds/ml/classical/regression/_elastic_net_regression.py +++ b/src/safeds/ml/classical/regression/_elastic_net_regression.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from sklearn.linear_model import ElasticNet as sk_ElasticNet @@ -15,7 +16,22 @@ class ElasticNetRegression(Regressor): """Elastic net regression.""" - def __init__(self) -> None: + def __init__(self, lasso_ratio: float = 0.5) -> None: + if lasso_ratio < 0 or lasso_ratio > 1: + raise ValueError("lasso_ratio must be between 0 and 1.") + elif lasso_ratio == 0: + warnings.warn( + "ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression." + " Use RidgeRegression instead for better numerical stability.", + stacklevel=1, + ) + elif lasso_ratio == 1: + warnings.warn( + "ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression." + " Use LassoRegression instead for better numerical stability.", + stacklevel=1, + ) + self.lasso_ratio = lasso_ratio self._wrapped_regressor: sk_ElasticNet | None = None self._feature_names: list[str] | None = None self._target_name: str | None = None @@ -41,10 +57,10 @@ def fit(self, training_set: TaggedTable) -> ElasticNetRegression: LearningError If the training data contains invalid values or if the training failed. """ - wrapped_regressor = sk_ElasticNet() + wrapped_regressor = sk_ElasticNet(l1_ratio=self.lasso_ratio) fit(wrapped_regressor, training_set) - result = ElasticNetRegression() + result = ElasticNetRegression(self.lasso_ratio) result._wrapped_regressor = wrapped_regressor result._feature_names = training_set.features.column_names result._target_name = training_set.target.name diff --git a/tests/safeds/ml/classical/regression/test_elastic_net_regression.py b/tests/safeds/ml/classical/regression/test_elastic_net_regression.py new file mode 100644 index 000000000..3ea0d3d4d --- /dev/null +++ b/tests/safeds/ml/classical/regression/test_elastic_net_regression.py @@ -0,0 +1,39 @@ +import pytest +from safeds.data.tabular.containers import Table +from safeds.ml.classical.regression._elastic_net_regression import ElasticNetRegression + + +def test_lasso_ratio_valid() -> None: + training_set = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]}) + tagged_training_set = training_set.tag_columns(target_name="col1", feature_names=["col2"]) + lasso_ratio = 0.3 + + elastic_net_regression = ElasticNetRegression(lasso_ratio).fit(tagged_training_set) + assert elastic_net_regression._wrapped_regressor is not None + assert elastic_net_regression._wrapped_regressor.l1_ratio == lasso_ratio + + +def test_lasso_ratio_invalid() -> None: + with pytest.raises(ValueError, match="lasso_ratio must be between 0 and 1."): + ElasticNetRegression(-1) + + +def test_lasso_ratio_zero() -> None: + with pytest.warns( + UserWarning, + match="ElasticNetRegression with lasso_ratio = 0 is essentially RidgeRegression." + " Use RidgeRegression instead for better numerical stability.", + ): + ElasticNetRegression(0) + + +def test_lasso_ratio_one() -> None: + with pytest.warns( + UserWarning, + match="ElasticNetRegression with lasso_ratio = 0 is essentially LassoRegression." + " Use LassoRegression instead for better numerical stability.", + ): + ElasticNetRegression(1) + + +# (Default parameter is tested in `test_regressor.py`.)