Skip to content

Commit

Permalink
feat: add Choice class for possible values of hyperparameter (#325)
Browse files Browse the repository at this point in the history
### Summary of Changes

Add a class to represent possible choices for the value of a
hyperparameter. This is in preparation for #264.
  • Loading branch information
lars-reimann authored May 26, 2023
1 parent ca046c4 commit d511c3e
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/safeds/ml/hyperparameters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Tools to work with hyperparameters of ML models."""

from ._choice import Choice

__all__ = ["Choice"]
63 changes: 63 additions & 0 deletions src/safeds/ml/hyperparameters/_choice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from collections.abc import Collection
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any

T = TypeVar("T")


class Choice(Collection[T]):
"""A list of values to choose from in a hyperparameter search."""

def __init__(self, *args: T) -> None:
"""
Create a new choice.
Parameters
----------
*args: tuple[T, ...]
The values to choose from.
"""
self.elements = list(args)

def __contains__(self, value: Any) -> bool:
"""
Check if a value is in this choice.
Parameters
----------
value: Any
The value to check.
Returns
-------
is_in_choice : bool
Whether the value is in this choice.
"""
return value in self.elements

def __iter__(self) -> Iterator[T]:
"""
Iterate over the values of this choice.
Returns
-------
iterator : Iterator[T]
An iterator over the values of this choice.
"""
return iter(self.elements)

def __len__(self) -> int:
"""
Get the number of values in this choice.
Returns
-------
number_of_values : int
The number of values in this choice.
"""
return len(self.elements)
Empty file.
63 changes: 63 additions & 0 deletions tests/safeds/ml/hyperparameters/test_choice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest
from safeds.ml.hyperparameters import Choice

if TYPE_CHECKING:
from typing import Any


class TestContains:
@pytest.mark.parametrize(
("choice", "value", "expected"),
[
(Choice(1, 2, 3), 1, True),
(Choice(1, 2, 3), 2, True),
(Choice(1, 2, 3), 3, True),
(Choice(1, 2, 3), 4, False),
(Choice(1, 2, 3), "3", False),
],
ids=[
"value in choice (start)",
"value in choice (middle)",
"value in choice (end)",
"value not in choice",
"value not in choice (wrong type)",
],
)
def test_should_check_whether_choice_contains_value(self, choice: Choice, value: Any, expected: bool) -> None:
assert (value in choice) == expected


class TestIter:
@pytest.mark.parametrize(
("choice", "expected"),
[
(Choice(), []),
(Choice(1, 2, 3), [1, 2, 3]),
],
ids=[
"empty",
"non-empty",
],
)
def test_should_iterate_values(self, choice: Choice, expected: list[Any]) -> None:
assert list(choice) == expected


class TestLen:
@pytest.mark.parametrize(
("choice", "expected"),
[
(Choice(), 0),
(Choice(1, 2, 3), 3),
],
ids=[
"empty",
"non-empty",
],
)
def test_should_return_number_of_values(self, choice: Choice, expected: int) -> None:
assert len(choice) == expected

0 comments on commit d511c3e

Please sign in to comment.