Skip to content

Commit

Permalink
Merge pull request #17 from Wytamma/add-dict-type
Browse files Browse the repository at this point in the history
Add dict type
  • Loading branch information
Wytamma authored Nov 17, 2024
2 parents 945f27c + 06ceeca commit ee5dc63
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 10 deletions.
29 changes: 24 additions & 5 deletions src/snk_cli/dynamic_typer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import typer
from typing import List, Callable
from click import Tuple
from typing import List, Callable, get_origin
from inspect import signature, Parameter
from makefun import with_signature
from enum import Enum

from .options import Option
import sys


def parse_colon_separated_pair(value: str):
return tuple(value.split(sep=':', maxsplit=2))

class DynamicTyper:
app: typer.Typer

Expand Down Expand Up @@ -150,12 +153,19 @@ def _create_cli_parameter(self, option: Option):
annotation_type = option.type
default = option.default
if option.choices:
annotation_type = Enum(f'{option.name}', {str(e): annotation_type(e) for e in option.choices})
if default:
try:
default = annotation_type(default)
except ValueError:
raise ValueError(f"Default value '{default}' for option '{option.name}' is not a valid choice.")
annotation_type = Enum(f'{option.name}', {str(e): annotation_type(e) for e in option.choices})
metavar, parser = None, None
if get_origin(annotation_type) is dict or annotation_type is dict:
metavar = "KEY:VALUE"
parser = parse_colon_separated_pair
annotation_type = List[Tuple]
if type(default) is dict:
default = [f"{k}:{v}" for k, v in default.items()]
return Parameter(
option.name,
kind=Parameter.POSITIONAL_OR_KEYWORD,
Expand All @@ -165,6 +175,8 @@ def _create_cli_parameter(self, option: Option):
help=f"{option.help}",
rich_help_panel="Workflow Configuration",
hidden=option.hidden,
metavar=metavar,
parser=parser,
),
annotation=annotation_type,
)
Expand Down Expand Up @@ -236,8 +248,15 @@ def func_wrapper(*args, **kwargs):
for snk_cli_option in options:

def add_option_to_args():
kwargs["ctx"].args.extend([f"--{snk_cli_option.name}", kwargs[snk_cli_option.name]])

value = kwargs[snk_cli_option.name]
if (snk_cli_option.type is dict or get_origin(snk_cli_option.type) is dict) and isinstance(value, list):
# Convert the list of tuples to a dictionary
value = dict(kwargs[snk_cli_option.name])
if get_origin(snk_cli_option.type) is dict:
# get the value type from the type hint
value_type = snk_cli_option.type.__args__[1]
value = {k: value_type(v) for k, v in value.items()}
kwargs["ctx"].args.extend([f"--{snk_cli_option.name}", value])
passed_via_command_line = self.check_if_option_passed_via_command_line(
snk_cli_option
)
Expand Down
5 changes: 4 additions & 1 deletion src/snk_cli/options/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
"list[int]": List[int],
"list[float]": List[float],
"pair": Tuple[str, str],
"dict": dict,
"dict[str, str]": dict[str, str],
"dict[str, int]": dict[str, int],
}

# Define the basic types for the combinations
Expand Down Expand Up @@ -111,9 +114,9 @@ def build_dynamic_cli_options(
Returns:
List[dict]: A list of options.
"""
flat_config = flatten(snakemake_config)
flat_annotations = flatten(snk_config.cli)
annotation_keys = get_keys_from_annotation(flat_annotations)
flat_config = flatten(snakemake_config, stop_at=annotation_keys)
options = {}

# For every parameter in the config, create an option from the corresponding annotation
Expand Down
9 changes: 6 additions & 3 deletions src/snk_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def check_command_available(command: str):
return which(command) is not None


def flatten(d, parent_key="", sep=":"):
def flatten(d, parent_key="", sep=":", stop_at=[]):
"""
Flattens a nested dictionary.
Expand All @@ -50,8 +50,11 @@ def flatten(d, parent_key="", sep=":"):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
if new_key in stop_at:
# this allows dict to be flattened up to a certain level so we can have dict types
items.append((new_key, v))
elif isinstance(v, MutableMapping):
items.extend(flatten(v, new_key, sep=sep, stop_at=stop_at).items())
else:
items.append((new_key, v))
return dict(items)
Expand Down
5 changes: 5 additions & 0 deletions tests/data/workflow/snk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ cli:
type: pair
enum:
choices: [a, b, c]
default: "a"
dict:
help: A dictionary
type: dict[str, int]
default: ["key:1"]
test:
another:
test:
Expand Down
42 changes: 41 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,51 @@ def test_pair(snakemake_config, annotations, cli_args, expected):
{"example": {"type": "int", "choices": [1, 2, 3]}},
["--example", 1],
"{'example': 1}"
)
),
(
{"example": 1},
{"example": {"type": "int", "choices": [2, 3], "default": 2}},
[],
"{'example': 2}"
),
])
def test_choices(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# dict
(
{"example": {"key": "value"}},
{"example": {"type": "dict", "default": ["key:value"]}},
[],
"{'example': {'key': 'value'}}"
),
(
{"example": {"key": "value"}},
{"example": {"type": "dict"}},
[],
"{'example': {'key': 'value'}}"
),
(
{"example": {"number": 1}},
{"example": {"type": "dict[str, int]"}},
[],
"{'example': {'number': 1}}"
),
(
{},
{"example": {"type": "dict[str, str]"}},
["--example", "new:2"],
"{'example': {'new': '2'}}"
)
])
def test_dict(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

0 comments on commit ee5dc63

Please sign in to comment.