diff --git a/src/snk_cli/dynamic_typer.py b/src/snk_cli/dynamic_typer.py index c411d4c..ae0e560 100644 --- a/src/snk_cli/dynamic_typer.py +++ b/src/snk_cli/dynamic_typer.py @@ -149,15 +149,13 @@ def _create_cli_parameter(self, option: Option): """ annotation_type = option.type default = option.default - if option.type is Enum or option.choices: - if not option.choices: - raise ValueError(f"Enum type {option.name} requires choices to be defined.") - annotation_type = Enum(f'{option.name}', {str(e): str(e) for e in option.choices}) + if option.choices: + annotation_type = Enum(f'{option.name}', {str(e): annotation_type(e) for e in option.choices}) if default: try: default = annotation_type(default) except ValueError: - raise ValueError(f"Default value {default} for option {option.name} is not a valid choice.") + raise ValueError(f"Default value '{default}' for option '{option.name}' is not a valid choice.") return Parameter( option.name, kind=Parameter.POSITIONAL_OR_KEYWORD, diff --git a/src/snk_cli/options/utils.py b/src/snk_cli/options/utils.py index b7a599b..b58c65d 100644 --- a/src/snk_cli/options/utils.py +++ b/src/snk_cli/options/utils.py @@ -63,7 +63,10 @@ def create_option_from_annotation( updated = False if config_default is None or default != config_default: updated = True - annotation_type = annotation_values.get(f"{annotation_key}:type", get_default_type(default)).lower() + annotation_type = annotation_values.get(f"{annotation_key}:type", None) + if annotation_type is not None: + assert annotation_type in types, f"Type '{annotation_type}' not supported." + annotation_type = annotation_type or get_default_type(default) annotation_type = types.get( annotation_type, List[str] if "list" in annotation_type else str ) @@ -75,6 +78,9 @@ def create_option_from_annotation( default=annotation_values.get(f"{annotation_key}:default", default) if default and get_origin(annotation_type) is tuple: assert len(default) == 2, f"Default value ({default}) for '{annotation_key}' should be a list of length 2." + choices = annotation_values.get(f"{annotation_key}:choices", None) + if choices: + assert isinstance(choices, list), f"Choices should be a list for '{annotation_key}'." return Option( name=name, original_key=annotation_key, diff --git a/src/snk_cli/utils.py b/src/snk_cli/utils.py index 53658a9..d1189b2 100644 --- a/src/snk_cli/utils.py +++ b/src/snk_cli/utils.py @@ -98,6 +98,10 @@ def serialise(d): if isinstance(d, Path) or isinstance(d, datetime): return str(d) + # check enum type + if hasattr(d, "value"): + return d.value + if isinstance(d, tuple): return list(d) diff --git a/tests/test_cli/test_snk_config.py b/tests/test_cli/test_snk_config.py index 3d2e8a3..cc78c90 100644 --- a/tests/test_cli/test_snk_config.py +++ b/tests/test_cli/test_snk_config.py @@ -44,7 +44,7 @@ def test_non_standard_configfile(tmp_path): assert "config2" in res.stdout, res.stderr def test_snk_config_with_enums(tmp_path): - runner = dynamic_runner({}, SnkConfig(cli={"test": {"choices": ["enum1", "enum2"], "type": "enum"}}), tmp_path=tmp_path) + runner = dynamic_runner({}, SnkConfig(cli={"test": {"choices": ["enum1", "enum2"], "type": "str"}}), tmp_path=tmp_path) res = runner.invoke(["run", "--help"]) assert res.exit_code == 0, res.stderr assert "enum1" in res.stdout, res.stderr diff --git a/tests/test_types.py b/tests/test_types.py index b903fa9..7761a8a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -140,3 +140,19 @@ def test_pair(snakemake_config, annotations, cli_args, expected): res = runner.invoke(["run"] + cli_args) assert res.exit_code == 0, res.stderr assert expected in res.stdout, res.stderr + +@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [ + # choices + ( + {"example": 1}, + {"example": {"type": "int", "choices": [1, 2, 3]}}, + ["--example", 1], + "{'example': 1}" + ) + ]) +def test_choices(snakemake_config, annotations, cli_args, expected): + snk_config = SnkConfig(cli=annotations) + runner = dynamic_runner(snakemake_config, snk_config) + res = runner.invoke(["run"] + cli_args) + assert res.exit_code == 0, res.stderr + assert expected in res.stdout, res.stderr