-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
Choice
class for possible values of hyperparameter (#325)
### Summary of Changes Add a class to represent possible choices for the value of a hyperparameter. This is in preparation for #264.
- Loading branch information
1 parent
ca046c4
commit d511c3e
Showing
4 changed files
with
131 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Tools to work with hyperparameters of ML models.""" | ||
|
||
from ._choice import Choice | ||
|
||
__all__ = ["Choice"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Collection | ||
from typing import TYPE_CHECKING, TypeVar | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Iterator | ||
from typing import Any | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class Choice(Collection[T]): | ||
"""A list of values to choose from in a hyperparameter search.""" | ||
|
||
def __init__(self, *args: T) -> None: | ||
""" | ||
Create a new choice. | ||
Parameters | ||
---------- | ||
*args: tuple[T, ...] | ||
The values to choose from. | ||
""" | ||
self.elements = list(args) | ||
|
||
def __contains__(self, value: Any) -> bool: | ||
""" | ||
Check if a value is in this choice. | ||
Parameters | ||
---------- | ||
value: Any | ||
The value to check. | ||
Returns | ||
------- | ||
is_in_choice : bool | ||
Whether the value is in this choice. | ||
""" | ||
return value in self.elements | ||
|
||
def __iter__(self) -> Iterator[T]: | ||
""" | ||
Iterate over the values of this choice. | ||
Returns | ||
------- | ||
iterator : Iterator[T] | ||
An iterator over the values of this choice. | ||
""" | ||
return iter(self.elements) | ||
|
||
def __len__(self) -> int: | ||
""" | ||
Get the number of values in this choice. | ||
Returns | ||
------- | ||
number_of_values : int | ||
The number of values in this choice. | ||
""" | ||
return len(self.elements) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import pytest | ||
from safeds.ml.hyperparameters import Choice | ||
|
||
if TYPE_CHECKING: | ||
from typing import Any | ||
|
||
|
||
class TestContains: | ||
@pytest.mark.parametrize( | ||
("choice", "value", "expected"), | ||
[ | ||
(Choice(1, 2, 3), 1, True), | ||
(Choice(1, 2, 3), 2, True), | ||
(Choice(1, 2, 3), 3, True), | ||
(Choice(1, 2, 3), 4, False), | ||
(Choice(1, 2, 3), "3", False), | ||
], | ||
ids=[ | ||
"value in choice (start)", | ||
"value in choice (middle)", | ||
"value in choice (end)", | ||
"value not in choice", | ||
"value not in choice (wrong type)", | ||
], | ||
) | ||
def test_should_check_whether_choice_contains_value(self, choice: Choice, value: Any, expected: bool) -> None: | ||
assert (value in choice) == expected | ||
|
||
|
||
class TestIter: | ||
@pytest.mark.parametrize( | ||
("choice", "expected"), | ||
[ | ||
(Choice(), []), | ||
(Choice(1, 2, 3), [1, 2, 3]), | ||
], | ||
ids=[ | ||
"empty", | ||
"non-empty", | ||
], | ||
) | ||
def test_should_iterate_values(self, choice: Choice, expected: list[Any]) -> None: | ||
assert list(choice) == expected | ||
|
||
|
||
class TestLen: | ||
@pytest.mark.parametrize( | ||
("choice", "expected"), | ||
[ | ||
(Choice(), 0), | ||
(Choice(1, 2, 3), 3), | ||
], | ||
ids=[ | ||
"empty", | ||
"non-empty", | ||
], | ||
) | ||
def test_should_return_number_of_values(self, choice: Choice, expected: int) -> None: | ||
assert len(choice) == expected |