Skip to content

Commit

Permalink
Merge pull request #16 from Wytamma/add-pair-type
Browse files Browse the repository at this point in the history
Add pair type
  • Loading branch information
Wytamma authored Nov 17, 2024
2 parents a969006 + c77aa09 commit 1c07aa6
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 12 deletions.
26 changes: 17 additions & 9 deletions src/snk_cli/options/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple, get_origin
from ..config.config import SnkConfig
from ..utils import get_default_type, flatten
from .option import Option
Expand All @@ -18,9 +18,18 @@
"list[str]": List[str],
"list[path]": List[Path],
"list[int]": List[int],
"enum": Enum,
"list[float]": List[float],
"pair": Tuple[str, str],
}

# Define the basic types for the combinations
basic_types = [int, str, bool, float]

# Add the combinations of the basic types to the `types` dictionary
for t1 in basic_types:
for t2 in basic_types:
types[f"pair[{t1.__name__}, {t2.__name__}]"] = Tuple[t1, t2]

def get_keys_from_annotation(annotations):
# Get the unique keys from the annotations
# preserving the order
Expand Down Expand Up @@ -50,27 +59,26 @@ def create_option_from_annotation(
Option: An Option object.
"""
config_default = default_values.get(annotation_key, None)

default = annotation_values.get(f"{annotation_key}:default", config_default)
updated = False
if config_default is None or default != config_default:
updated = True
type = annotation_values.get(f"{annotation_key}:type", get_default_type(default))
assert (
type is not None
), f"Type for {annotation_key} should be one of {', '.join(types.keys())}."
annotation_type = annotation_values.get(f"{annotation_key}:type", get_default_type(default)).lower()
annotation_type = types.get(
type.lower(), List[str] if "list" in type.lower() else str
annotation_type, List[str] if "list" in annotation_type else str
)
name = annotation_values.get(
f"{annotation_key}:name", annotation_key.replace(":", "_")
).replace("-", "_")
short = annotation_values.get(f"{annotation_key}:short", None)
hidden = annotation_values.get(f"{annotation_key}:hidden", False)
default=annotation_values.get(f"{annotation_key}:default", default)
if default and get_origin(annotation_type) is tuple:
assert len(default) == 2, f"Default value ({default}) for '{annotation_key}' should be a list of length 2."
return Option(
name=name,
original_key=annotation_key,
default=annotation_values.get(f"{annotation_key}:default", default),
default=default,
updated=updated,
help=annotation_values.get(f"{annotation_key}:help", ""),
type=annotation_type,
Expand Down
5 changes: 4 additions & 1 deletion src/snk_cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def serialise(d):
"""
if isinstance(d, Path) or isinstance(d, datetime):
return str(d)

if isinstance(d, tuple):
return list(d)

if isinstance(d, list):
return [serialise(x) for x in d]
Expand Down Expand Up @@ -155,7 +158,7 @@ def parse_config_args(args: List[str], options: List[Option]):

def get_default_type(v):
default_type = type(v)
if default_type == list and len(v) > 0:
if default_type is list and len(v) > 0:
return f"List[{type(v[0]).__name__}]"
return str(default_type.__name__)

Expand Down
9 changes: 7 additions & 2 deletions src/snk_cli/validate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Union, get_origin
from .config import SnkConfig
from .options.utils import types
import inspect
Expand Down Expand Up @@ -56,11 +56,16 @@ def validate_and_transform_in_place(config: Dict[str, Any], validation: Validati
if val_type is None:
raise ValueError(f"Unknown type '{val_info['type']}'")
try:
if getattr(val_type, "__origin__", None) == list:
if getattr(val_type, "__origin__", None) is list:
val_type = val_type.__args__[0]
if not isinstance(value, list):
raise ValueError(f"Expected a list for key '{key}'")
config[key] = [val_type(v) for v in value]
elif get_origin(val_type) is tuple:
assert len(value) == 2, f"Expected a list of length 2 for key '{key}'"
key_type = val_type.__args__[0]
val_type = val_type.__args__[1]
config[key] = [key_type(value[0]), val_type(value[1])]
else:
config[key] = val_type(value)
except (ValueError, TypeError) as e:
Expand Down
4 changes: 4 additions & 0 deletions tests/data/workflow/snk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ cli:
type: bool
null_annotation:
default: null
KeyValuePair:
default: ["key", "value"]
help: A key-value pair
type: pair
enum:
choices: [a, b, c]
test:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_cli/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@
{"example": {"type": "list[path]"}},
{"example": [Path("1"), Path("2")]}
),
# pair type
(
{"example": ["1", "2"]},
{"example": {"type": "pair[int, int]"}},
{"example": [1, 2]}
),
# choices
(
{"example": "a"},
{"example": {"type": "str", "choices": ["a", "b"]}},
{"example": "a"}
),
# nested dictionary
(
{"example": {"nested": "1"}},
Expand Down
142 changes: 142 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# integration tests for the annotation types from the snk config, snakemake config, and the command line interface

import pytest
from .utils import dynamic_runner
from snk_cli.config import SnkConfig


@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# str
(
{"example": "1"},
{"example": {"type": "str"}},
["--example", "1"],
"{'example': '1'}"
)])
def test_str(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# int
(
{"example": "1"},
{"example": {"type": "int"}},
["--example", "1"],
"{'example': 1}"
)])
def test_int(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# float
(
{"example": "1"},
{"example": {"type": "float"}},
["--example", "1"],
"{'example': 1.0}"
)])
def test_float(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# bool
(
{"example": "1"},
{"example": {"type": "bool"}},
["--example"],
"{'example': True}"
)])
def test_bool(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# path
(
{"example": "file"},
{"example": {"type": "path"}},
["--example", "file"],
"{'example': 'file'}"
)])
def test_path(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# list
(
{"example": [1,2,3]},
{"example": {"type": "list"}},
["--example", "1", "--example", "2", "--example", "3"],
"{'example': ['1', '2', '3']}"
)])
def test_list(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

@pytest.mark.parametrize("snakemake_config,annotations,cli_args,expected", [
# pair
(
{"example": [1, 2]},
{"example": {"type": "pair"}},
["--example", "1", "2"],
"{'example': ['1', '2']}"
),
(
{"example": [1, 2]},
{"example": {"type": "pair[int, int]"}},
["--example", "1", "2"],
"{'example': [1, 2]}"
),
(
{"example": ["1", "2"]},
{"example": {"type": "pair[float, float]"}},
["--example", "1", "2"],
"{'example': [1.0, 2.0]}"
),
(
{"example": ["1", "2"]},
{"example": {"type": "pair[str, str]"}},
["--example", "1", "2"],
"{'example': ['1', '2']}"
),
(
{"example": ["1", "2"]},
{"example": {"type": "pair[str, int]"}},
["--example", "1", "2"],
"{'example': ['1', 2]}"
),
(
{"example": ["1", "2"]},
{"example": {"type": "pair[int, str]"}},
["--example", "1", "2"],
"{'example': [1, '2']}"
)
])
def test_pair(snakemake_config, annotations, cli_args, expected):
snk_config = SnkConfig(cli=annotations)
runner = dynamic_runner(snakemake_config, snk_config)
res = runner.invoke(["run"] + cli_args)
assert res.exit_code == 0, res.stderr
assert expected in res.stdout, res.stderr

0 comments on commit 1c07aa6

Please sign in to comment.