Skip to content

Commit

Permalink
Make parsing parameter from CLI work properly when it is a list (#128)
Browse files Browse the repository at this point in the history
When the values of a parameter are provided via CLI, split them into a list properly with the default separator is a comma
  • Loading branch information
dangquangdon authored Apr 22, 2024
1 parent 681890e commit 88d89a5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_prepare(tmpdir, monkeypatch):
"makemeqwer": "asdf",
"makeme321": 123,
"makemenegative": 0.0001,
"list_bar": ["bar1", "bar2"],
}
url1 = "https://dist.valohai.com/valohai-utils-tests/Example.svg"
url2 = "https://dist.valohai.com/valohai-utils-tests/sharktavern.jpg"
Expand All @@ -51,6 +52,7 @@ def test_prepare(tmpdir, monkeypatch):
"--makemenegative=-0.123",
"--some_totally_random_parameter_to_ignore=666",
f"--overrideme={str(local_file)}",
"--list_bar=bar1,bar2,bar3",
]
m.setattr(
sys,
Expand All @@ -70,6 +72,7 @@ def test_prepare(tmpdir, monkeypatch):
assert valohai.parameters("makemeqwer").value == "qwer"
assert valohai.parameters("makeme321").value == 321
assert valohai.parameters("makemenegative").value < 0.0
assert valohai.parameters("list_bar").value == ["bar1", "bar2", "bar3"]

assert (
get_input_info("example").files[0].uri
Expand Down
2 changes: 2 additions & 0 deletions valohai/internals/global_state_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def parse_overrides_from_cli(
if isinstance(value, bool):
# We need to fiddle booleans in a bit different way, since they are treated as flags per default
parser.add_argument(f"--{name}", type=string_to_bool, nargs="?", const=True)
elif isinstance(value, list):
parser.add_argument(f"--{name}", type=lambda s: str(s).split(","))
else:
parser.add_argument(f"--{name}", type=type(value))
known_args, unknown_args = parser.parse_known_args()
Expand Down

0 comments on commit 88d89a5

Please sign in to comment.