Skip to content

Commit

Permalink
CLI ignore external parser list fix (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab authored Sep 7, 2024
1 parent 818d56e commit 94f3a90
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 27 deletions.
62 changes: 38 additions & 24 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from enum import Enum
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -1155,13 +1156,15 @@ def __call__(self, *, args: list[str] | tuple[str, ...] | bool) -> CliSettingsSo
...

@overload
def __call__(self, *, parsed_args: Namespace | dict[str, list[str] | str]) -> CliSettingsSource[T]:
def __call__(
self, *, parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str]
) -> CliSettingsSource[T]:
"""
Loads parsed command line arguments into the CLI settings source.
Note:
The parsed args must be in `argparse.Namespace` or vars dictionary (e.g., vars(argparse.Namespace))
format.
The parsed args must be in `argparse.Namespace`, `SimpleNamespace`, or vars dictionary
(e.g., vars(argparse.Namespace)) format.
Args:
parsed_args: The parsed args to load.
Expand All @@ -1175,7 +1178,7 @@ def __call__(
self,
*,
args: list[str] | tuple[str, ...] | bool | None = None,
parsed_args: Namespace | dict[str, list[str] | str] | None = None,
parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str] | None = None,
) -> dict[str, Any] | CliSettingsSource[T]:
if args is not None and parsed_args is not None:
raise SettingsError('`args` and `parsed_args` are mutually exclusive')
Expand All @@ -1194,13 +1197,15 @@ def __call__(
def _load_env_vars(self) -> Mapping[str, str | None]: ...

@overload
def _load_env_vars(self, *, parsed_args: Namespace | dict[str, list[str] | str]) -> CliSettingsSource[T]:
def _load_env_vars(
self, *, parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str]
) -> CliSettingsSource[T]:
"""
Loads the parsed command line arguments into the CLI environment settings variables.
Note:
The parsed args must be in `argparse.Namespace` or vars dictionary (e.g., vars(argparse.Namespace))
format.
The parsed args must be in `argparse.Namespace`, `SimpleNamespace`, or vars dictionary
(e.g., vars(argparse.Namespace)) format.
Args:
parsed_args: The parsed args to load.
Expand All @@ -1211,12 +1216,12 @@ def _load_env_vars(self, *, parsed_args: Namespace | dict[str, list[str] | str])
...

def _load_env_vars(
self, *, parsed_args: Namespace | dict[str, list[str] | str] | None = None
self, *, parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str] | None = None
) -> Mapping[str, str | None] | CliSettingsSource[T]:
if parsed_args is None:
return {}

if isinstance(parsed_args, Namespace):
if isinstance(parsed_args, (Namespace, SimpleNamespace)):
parsed_args = vars(parsed_args)

selected_subcommands: list[str] = []
Expand Down Expand Up @@ -1246,26 +1251,35 @@ def _load_env_vars(

return self

def _get_merge_parsed_list_types(
self, parsed_list: list[str], field_name: str
) -> tuple[Optional[type], Optional[type]]:
merge_type = self._cli_dict_args.get(field_name, list)
if (
merge_type is list
or not origin_is_union(get_origin(merge_type))
or not any(
type_
for type_ in get_args(merge_type)
if type_ is not type(None) and get_origin(type_) not in (dict, Mapping)
)
):
inferred_type = merge_type
else:
inferred_type = list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str

return merge_type, inferred_type

def _merge_parsed_list(self, parsed_list: list[str], field_name: str) -> str:
try:
merged_list: list[str] = []
is_last_consumed_a_value = False
merge_type = self._cli_dict_args.get(field_name, list)
if (
merge_type is list
or not origin_is_union(get_origin(merge_type))
or not any(
type_
for type_ in get_args(merge_type)
if type_ is not type(None) and get_origin(type_) not in (dict, Mapping)
)
):
inferred_type = merge_type
else:
inferred_type = (
list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str
)
merge_type, inferred_type = self._get_merge_parsed_list_types(parsed_list, field_name)
for val in parsed_list:
if not isinstance(val, str):
# If val is not a string, it's from an external parser and we can ignore parsing the rest of the
# list.
break
val = val.strip()
if val.startswith('[') and val.endswith(']'):
val = val[1:-1].strip()
Expand Down
36 changes: 33 additions & 3 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3665,21 +3665,51 @@ class Cfg(BaseSettings):
cli_cfg_settings = CliSettingsSource(Cfg, cli_prefix=prefix, root_parser=parser)

add_arg('--fruit', choices=['pear', 'kiwi', 'lime'])
add_arg('--num-list', action='append', type=int)
add_arg('--num', type=int)

args = ['--fruit', 'pear']
args = ['--fruit', 'pear', '--num', '0', '--num-list', '1', '--num-list', '2', '--num-list', '3']
parsed_args = parse_args(args)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {'pet': 'bird'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {'pet': 'bird'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'}

arg_prefix = f'{prefix}.' if prefix else ''
args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog']
args = [
'--fruit',
'kiwi',
'--num',
'0',
'--num-list',
'1',
'--num-list',
'2',
'--num-list',
'3',
f'--{arg_prefix}pet',
'dog',
]
parsed_args = parse_args(args)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {'pet': 'dog'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {'pet': 'dog'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'}

parsed_args = parse_args(['--fruit', 'kiwi', f'--{arg_prefix}pet', 'cat'])
parsed_args = parse_args(
[
'--fruit',
'kiwi',
'--num',
'0',
'--num-list',
'1',
'--num-list',
'2',
'--num-list',
'3',
f'--{arg_prefix}pet',
'cat',
]
)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == {'pet': 'cat'}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=False)).model_dump() == {'pet': 'bird'}

Expand Down

0 comments on commit 94f3a90

Please sign in to comment.