Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Choice class for possible values of hyperparameter #325

Merged
merged 2 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/safeds/ml/hyperparameters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Tools to work with hyperparameters of ML models."""

from ._choice import Choice

__all__ = ["Choice"]
63 changes: 63 additions & 0 deletions src/safeds/ml/hyperparameters/_choice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from collections.abc import Collection
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any

T = TypeVar("T")


class Choice(Collection[T]):
"""A list of values to choose from in a hyperparameter search."""

def __init__(self, *args: T) -> None:
"""
Create a new choice.

Parameters
----------
*args: tuple[T, ...]
The values to choose from.
"""
self.elements = list(args)

def __contains__(self, value: Any) -> bool:
"""
Check if a value is in this choice.

Parameters
----------
value: Any
The value to check.

Returns
-------
is_in_choice : bool
Whether the value is in this choice.
"""
return value in self.elements

def __iter__(self) -> Iterator[T]:
"""
Iterate over the values of this choice.

Returns
-------
iterator : Iterator[T]
An iterator over the values of this choice.
"""
return iter(self.elements)

def __len__(self) -> int:
"""
Get the number of values in this choice.

Returns
-------
number_of_values : int
The number of values in this choice.
"""
return len(self.elements)
Empty file.
63 changes: 63 additions & 0 deletions tests/safeds/ml/hyperparameters/test_choice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest
from safeds.ml.hyperparameters import Choice

if TYPE_CHECKING:
from typing import Any


class TestContains:
@pytest.mark.parametrize(
("choice", "value", "expected"),
[
(Choice(1, 2, 3), 1, True),
(Choice(1, 2, 3), 2, True),
(Choice(1, 2, 3), 3, True),
(Choice(1, 2, 3), 4, False),
(Choice(1, 2, 3), "3", False),
],
ids=[
"value in choice (start)",
"value in choice (middle)",
"value in choice (end)",
"value not in choice",
"value not in choice (wrong type)",
],
)
def test_should_check_whether_choice_contains_value(self, choice: Choice, value: Any, expected: bool) -> None:
assert (value in choice) == expected


class TestIter:
@pytest.mark.parametrize(
("choice", "expected"),
[
(Choice(), []),
(Choice(1, 2, 3), [1, 2, 3]),
],
ids=[
"empty",
"non-empty",
],
)
def test_should_iterate_values(self, choice: Choice, expected: list[Any]) -> None:
assert list(choice) == expected


class TestLen:
@pytest.mark.parametrize(
("choice", "expected"),
[
(Choice(), 0),
(Choice(1, 2, 3), 3),
],
ids=[
"empty",
"non-empty",
],
)
def test_should_return_number_of_values(self, choice: Choice, expected: int) -> None:
assert len(choice) == expected