From d511c3eed779e64fc53499e7c2eb2e8292955645 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Fri, 26 May 2023 12:46:45 +0200 Subject: [PATCH] feat: add `Choice` class for possible values of hyperparameter (#325) ### Summary of Changes Add a class to represent possible choices for the value of a hyperparameter. This is in preparation for #264. --- src/safeds/ml/hyperparameters/__init__.py | 5 ++ src/safeds/ml/hyperparameters/_choice.py | 63 +++++++++++++++++++ tests/safeds/ml/hyperparameters/__init__.py | 0 .../safeds/ml/hyperparameters/test_choice.py | 63 +++++++++++++++++++ 4 files changed, 131 insertions(+) create mode 100644 src/safeds/ml/hyperparameters/__init__.py create mode 100644 src/safeds/ml/hyperparameters/_choice.py create mode 100644 tests/safeds/ml/hyperparameters/__init__.py create mode 100644 tests/safeds/ml/hyperparameters/test_choice.py diff --git a/src/safeds/ml/hyperparameters/__init__.py b/src/safeds/ml/hyperparameters/__init__.py new file mode 100644 index 000000000..67291eba7 --- /dev/null +++ b/src/safeds/ml/hyperparameters/__init__.py @@ -0,0 +1,5 @@ +"""Tools to work with hyperparameters of ML models.""" + +from ._choice import Choice + +__all__ = ["Choice"] diff --git a/src/safeds/ml/hyperparameters/_choice.py b/src/safeds/ml/hyperparameters/_choice.py new file mode 100644 index 000000000..4eb612804 --- /dev/null +++ b/src/safeds/ml/hyperparameters/_choice.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from collections.abc import Collection +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any + +T = TypeVar("T") + + +class Choice(Collection[T]): + """A list of values to choose from in a hyperparameter search.""" + + def __init__(self, *args: T) -> None: + """ + Create a new choice. + + Parameters + ---------- + *args: tuple[T, ...] + The values to choose from. + """ + self.elements = list(args) + + def __contains__(self, value: Any) -> bool: + """ + Check if a value is in this choice. + + Parameters + ---------- + value: Any + The value to check. + + Returns + ------- + is_in_choice : bool + Whether the value is in this choice. + """ + return value in self.elements + + def __iter__(self) -> Iterator[T]: + """ + Iterate over the values of this choice. + + Returns + ------- + iterator : Iterator[T] + An iterator over the values of this choice. + """ + return iter(self.elements) + + def __len__(self) -> int: + """ + Get the number of values in this choice. + + Returns + ------- + number_of_values : int + The number of values in this choice. + """ + return len(self.elements) diff --git a/tests/safeds/ml/hyperparameters/__init__.py b/tests/safeds/ml/hyperparameters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/safeds/ml/hyperparameters/test_choice.py b/tests/safeds/ml/hyperparameters/test_choice.py new file mode 100644 index 000000000..8adcd5952 --- /dev/null +++ b/tests/safeds/ml/hyperparameters/test_choice.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from safeds.ml.hyperparameters import Choice + +if TYPE_CHECKING: + from typing import Any + + +class TestContains: + @pytest.mark.parametrize( + ("choice", "value", "expected"), + [ + (Choice(1, 2, 3), 1, True), + (Choice(1, 2, 3), 2, True), + (Choice(1, 2, 3), 3, True), + (Choice(1, 2, 3), 4, False), + (Choice(1, 2, 3), "3", False), + ], + ids=[ + "value in choice (start)", + "value in choice (middle)", + "value in choice (end)", + "value not in choice", + "value not in choice (wrong type)", + ], + ) + def test_should_check_whether_choice_contains_value(self, choice: Choice, value: Any, expected: bool) -> None: + assert (value in choice) == expected + + +class TestIter: + @pytest.mark.parametrize( + ("choice", "expected"), + [ + (Choice(), []), + (Choice(1, 2, 3), [1, 2, 3]), + ], + ids=[ + "empty", + "non-empty", + ], + ) + def test_should_iterate_values(self, choice: Choice, expected: list[Any]) -> None: + assert list(choice) == expected + + +class TestLen: + @pytest.mark.parametrize( + ("choice", "expected"), + [ + (Choice(), 0), + (Choice(1, 2, 3), 3), + ], + ids=[ + "empty", + "non-empty", + ], + ) + def test_should_return_number_of_values(self, choice: Choice, expected: int) -> None: + assert len(choice) == expected