diff --git a/src/safeds/ml/classical/classification/_support_vector_machine.py b/src/safeds/ml/classical/classification/_support_vector_machine.py index 552e23403..e9c1edafc 100644 --- a/src/safeds/ml/classical/classification/_support_vector_machine.py +++ b/src/safeds/ml/classical/classification/_support_vector_machine.py @@ -13,14 +13,30 @@ class SupportVectorMachine(Classifier): - """Support vector machine.""" + """ + Support vector machine. - def __init__(self) -> None: + Parameters + ---------- + c: float + The strength of regularization. Must be strictly positive. + + Raises + ------ + ValueError + If `c` is less than or equal to 0. + """ + + def __init__(self, c: float = 1.0) -> None: # Internal state self._wrapped_classifier: sk_SVC | None = None self._feature_names: list[str] | None = None self._target_name: str | None = None + if c <= 0: + raise ValueError("The strength of regularization given by the c parameter must be strictly positive.") + self._c = c + def fit(self, training_set: TaggedTable) -> SupportVectorMachine: """ Create a copy of this classifier and fit it with the given training data. @@ -42,10 +58,10 @@ def fit(self, training_set: TaggedTable) -> SupportVectorMachine: LearningError If the training data contains invalid values or if the training failed. """ - wrapped_classifier = sk_SVC() + wrapped_classifier = sk_SVC(C=self._c) fit(wrapped_classifier, training_set) - result = SupportVectorMachine() + result = SupportVectorMachine(self._c) result._wrapped_classifier = wrapped_classifier result._feature_names = training_set.features.column_names result._target_name = training_set.target.name diff --git a/src/safeds/ml/classical/regression/_support_vector_machine.py b/src/safeds/ml/classical/regression/_support_vector_machine.py index c4c2ee543..742a83c78 100644 --- a/src/safeds/ml/classical/regression/_support_vector_machine.py +++ b/src/safeds/ml/classical/regression/_support_vector_machine.py @@ -13,14 +13,30 @@ class SupportVectorMachine(Regressor): - """Support vector machine.""" + """ + Support vector machine. - def __init__(self) -> None: + Parameters + ---------- + c: float + The strength of regularization. Must be strictly positive. + + Raises + ------ + ValueError + If `c` is less than or equal to 0. + """ + + def __init__(self, c: float = 1.0) -> None: # Internal state self._wrapped_regressor: sk_SVR | None = None self._feature_names: list[str] | None = None self._target_name: str | None = None + if c <= 0: + raise ValueError("The strength of regularization given by the c parameter must be strictly positive.") + self._c = c + def fit(self, training_set: TaggedTable) -> SupportVectorMachine: """ Create a copy of this regressor and fit it with the given training data. @@ -42,10 +58,10 @@ def fit(self, training_set: TaggedTable) -> SupportVectorMachine: LearningError If the training data contains invalid values or if the training failed. """ - wrapped_regressor = sk_SVR() + wrapped_regressor = sk_SVR(C=self._c) fit(wrapped_regressor, training_set) - result = SupportVectorMachine() + result = SupportVectorMachine(self._c) result._wrapped_regressor = wrapped_regressor result._feature_names = training_set.features.column_names result._target_name = training_set.target.name diff --git a/tests/safeds/ml/classical/classification/test_support_vector_machine.py b/tests/safeds/ml/classical/classification/test_support_vector_machine.py new file mode 100644 index 000000000..606614735 --- /dev/null +++ b/tests/safeds/ml/classical/classification/test_support_vector_machine.py @@ -0,0 +1,27 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable +from safeds.ml.classical.classification import SupportVectorMachine + + +@pytest.fixture() +def training_set() -> TaggedTable: + table = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]}) + return table.tag_columns(target_name="col1", feature_names=["col2"]) + + +class TestC: + def test_should_be_passed_to_fitted_model(self, training_set: TaggedTable) -> None: + fitted_model = SupportVectorMachine(c=2).fit(training_set=training_set) + assert fitted_model._c == 2 + + def test_should_be_passed_to_sklearn(self, training_set: TaggedTable) -> None: + fitted_model = SupportVectorMachine(c=2).fit(training_set) + assert fitted_model._wrapped_classifier is not None + assert fitted_model._wrapped_classifier.C == 2 + + def test_should_raise_if_less_than_or_equal_to_0(self) -> None: + with pytest.raises( + ValueError, + match="The strength of regularization given by the c parameter must be strictly positive.", + ): + SupportVectorMachine(c=-1) diff --git a/tests/safeds/ml/classical/regression/test_support_vector_machine.py b/tests/safeds/ml/classical/regression/test_support_vector_machine.py new file mode 100644 index 000000000..04166a951 --- /dev/null +++ b/tests/safeds/ml/classical/regression/test_support_vector_machine.py @@ -0,0 +1,27 @@ +import pytest +from safeds.data.tabular.containers import Table, TaggedTable +from safeds.ml.classical.regression import SupportVectorMachine + + +@pytest.fixture() +def training_set() -> TaggedTable: + table = Table.from_dict({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]}) + return table.tag_columns(target_name="col1", feature_names=["col2"]) + + +class TestC: + def test_should_be_passed_to_fitted_model(self, training_set: TaggedTable) -> None: + fitted_model = SupportVectorMachine(c=2).fit(training_set=training_set) + assert fitted_model._c == 2 + + def test_should_be_passed_to_sklearn(self, training_set: TaggedTable) -> None: + fitted_model = SupportVectorMachine(c=2).fit(training_set) + assert fitted_model._wrapped_regressor is not None + assert fitted_model._wrapped_regressor.C == 2 + + def test_should_raise_if_less_than_or_equal_to_0(self) -> None: + with pytest.raises( + ValueError, + match="The strength of regularization given by the c parameter must be strictly positive.", + ): + SupportVectorMachine(c=-1)