Skip to content

Commit

Permalink
🐛 fix choices
Browse files Browse the repository at this point in the history
  • Loading branch information
Wytamma committed Nov 17, 2024
1 parent 1c07aa6 commit 945f27c
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 7 deletions.
8 changes: 3 additions & 5 deletions src/snk_cli/dynamic_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,13 @@ def _create_cli_parameter(self, option: Option):
"""
annotation_type = option.type
default = option.default
if option.type is Enum or option.choices:
if not option.choices:
raise ValueError(f"Enum type {option.name} requires choices to be defined.")
annotation_type = Enum(f'{option.name}', {str(e): str(e) for e in option.choices})
if option.choices:
annotation_type = Enum(f'{option.name}', {str(e): annotation_type(e) for e in option.choices})
if default:
try:
default = annotation_type(default)
except ValueError:
raise ValueError(f"Default value {default} for option {option.name} is not a valid choice.")
raise ValueError(f"Default value '{default}' for option '{option.name}' is not a valid choice.")
return Parameter(
option.name,
kind=Parameter.POSITIONAL_OR_KEYWORD,
Expand Down
8 changes: 7 additions & 1 deletion src/snk_cli/options/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def create_option_from_annotation(
updated = False
if config_default is None or default != config_default:
updated = True
annotation_type = annotation_values.get(f"{annotation_key}:type", get_default_type(default)).lower()
annotation_type = annotation_values.get(f"{annotation_key}:type", None)
if annotation_type is not None:
assert annotation_type in types, f"Type '{annotation_type}' not supported."
annotation_type = annotation_type or get_default_type(default)
annotation_type = types.get(
annotation_type, List[str] if "list" in annotation_type else str
)
Expand All @@ -75,6 +78,9 @@ def create_option_from_annotation(
default=annotation_values.get(f"{annotation_key}:default", default)
if default and get_origin(annotation_type) is tuple:
assert len(default) == 2, f"Default value ({default}) for '{annotation_key}' should be a list of length 2."
choices = annotation_values.get(f"{annotation_key}:choices", None)
if choices:
assert isinstance(choices, list), f"Choices should be a list for '{annotation_key}'."
return Option(
name=name,
original_key=annotation_key,
Expand Down
4 changes: 4 additions & 0 deletions src/snk_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def serialise(d):
if isinstance(d, Path) or isinstance(d, datetime):
return str(d)

# check enum type
if hasattr(d, "value"):
return d.value

if isinstance(d, tuple):
return list(d)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli/test_snk_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_non_standard_configfile(tmp_path):
assert "config2" in res.stdout, res.stderr

def test_snk_config_with_enums(tmp_path):
runner = dynamic_runner({}, SnkConfig(cli={"test": {"choices": ["enum1", "enum2"], "type": "enum"}}), tmp_path=tmp_path)
runner = dynamic_runner({}, SnkConfig(cli={"test": {"choices": ["enum1", "enum2"], "type": "str"}}), tmp_path=tmp_path)
res = runner.invoke(["run", "--help"])
assert res.exit_code == 0, res.stderr
assert "enum1" in res.stdout, res.stderr
Expand Down
16 changes: 16 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,19 @@ def test_pair(snakemake_config, annotations, cli_args, expected):
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# choices
(
{"example": 1},
{"example": {"type": "int", "choices": [1, 2, 3]}},
["--example", 1],
"{'example': 1}"
)
])
def test_choices(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

0 comments on commit 945f27c

Please sign in to comment.