Skip to content

Commit

Permalink
Correct accepted types for Input.keep_directories
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Nov 4, 2021
1 parent eaff2f1 commit 7743b5c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
14 changes: 13 additions & 1 deletion tests/test_step_parsing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from valohai_yaml.objs.input import KeepDirectories
import pytest

from valohai_yaml.objs.input import Input, KeepDirectories


def test_parse_inputs(example2_config):
Expand Down Expand Up @@ -127,3 +129,13 @@ def test_input_extras(input_extras_config):
assert step.inputs['model'].filename == "model.pb"
assert step.inputs['foos'].keep_directories == KeepDirectories.FULL
assert step.inputs['bars'].keep_directories == KeepDirectories.SUFFIX


@pytest.mark.parametrize("value, expected", [(kd, kd) for kd in KeepDirectories] + [
("full", KeepDirectories.FULL), # type: ignore[list-item]
(False, KeepDirectories.NONE), # type: ignore[list-item]
(True, KeepDirectories.FULL), # type: ignore[list-item]
("suffix", KeepDirectories.SUFFIX), # type: ignore[list-item]
])
def test_input_keep_directories(value, expected):
assert Input(name="foo", keep_directories=value).keep_directories == expected
8 changes: 6 additions & 2 deletions valohai_yaml/objs/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from .base import Item

KeepDirectoriesValue = Union[bool, str, 'KeepDirectories']


class KeepDirectories(Enum):
"""How to retain directories when using storage wildcards."""
Expand All @@ -12,7 +14,9 @@ class KeepDirectories(Enum):
FULL = 'full'

@classmethod
def cast(cls, value: Union[bool, str]) -> 'KeepDirectories':
def cast(cls, value: KeepDirectoriesValue) -> 'KeepDirectories':
if isinstance(value, KeepDirectories):
return value
if not value:
return KeepDirectories.NONE
if value is True:
Expand All @@ -30,7 +34,7 @@ def __init__(
default: Optional[Union[List[str], str]] = None,
optional: bool = False,
description: Optional[str] = None,
keep_directories: bool = False,
keep_directories: KeepDirectoriesValue = False,
filename: Optional[str] = None
) -> None:
self.name = name
Expand Down

0 comments on commit 7743b5c

Please sign in to comment.