Skip to content

Commit

Permalink
Add CliSettingsSource alias handling for AliasChoices and AliasPath. (#…
Browse files Browse the repository at this point in the history
…313)

Co-authored-by: Hasan Ramezani <[email protected]>
  • Loading branch information
kschwab and hramezani authored Jun 13, 2024
1 parent abe7cc5 commit bd294a4
Show file tree
Hide file tree
Showing 3 changed files with 294 additions and 74 deletions.
35 changes: 35 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,38 @@ print(Settings().model_dump())
#> {'fruit': <Fruit.lime: 2>, 'pet': 'cat'}
```

#### Aliases

Pydantic field aliases are added as CLI argument aliases.

```py
import sys

from pydantic import AliasChoices, AliasPath, Field

from pydantic_settings import BaseSettings


class User(BaseSettings, cli_parse_args=True):
first_name: str = Field(
validation_alias=AliasChoices('fname', AliasPath('name', 0))
)
last_name: str = Field(validation_alias=AliasChoices('lname', AliasPath('name', 1)))


sys.argv = ['example.py', '--fname', 'John', '--lname', 'Doe']
print(User().model_dump())
#> {'first_name': 'John', 'last_name': 'Doe'}

sys.argv = ['example.py', '--name', 'John,Doe']
print(User().model_dump())
#> {'first_name': 'John', 'last_name': 'Doe'}

sys.argv = ['example.py', '--name', 'John', '--lname', 'Doe']
print(User().model_dump())
#> {'first_name': 'John', 'last_name': 'Doe'}
```

### Subcommands and Positional Arguments

Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These
Expand All @@ -681,6 +713,9 @@ subcommands must be a valid type derived from the pydantic `BaseModel` class.
set of subcommands. For more information on subparsers, see [argparse
subcommands](https://docs.python.org/3/library/argparse.html#sub-commands).

!!! note
`CliSubCommand` and `CliPositionalArg` are always case sensitive and do not support aliases.

```py
import sys

Expand Down
209 changes: 163 additions & 46 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,30 +1157,67 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
sub_models.append(type_) # type: ignore
return sub_models

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, str, FieldInfo]]:
def _get_resolved_names(
self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str]
) -> tuple[tuple[str, ...], bool]:
resolved_names: list[str] = []
is_alias_path_only: bool = True
if not any((field_info.alias, field_info.validation_alias)):
resolved_names += [field_name]
is_alias_path_only = False
else:
new_alias_paths: list[AliasPath] = []
for alias in (field_info.alias, field_info.validation_alias):
if alias is None:
continue
elif isinstance(alias, str):
resolved_names.append(alias)
is_alias_path_only = False
elif isinstance(alias, AliasChoices):
for name in alias.choices:
if isinstance(name, str):
resolved_names.append(name)
is_alias_path_only = False
else:
new_alias_paths.append(name)
else:
new_alias_paths.append(alias)
for alias_path in new_alias_paths:
name = cast(str, alias_path.path[0])
name = name.lower() if not self.case_sensitive else name
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
if not resolved_names and is_alias_path_only:
resolved_names.append(name)
if not self.case_sensitive:
resolved_names = [resolved_name.lower() for resolved_name in resolved_names]
return tuple(dict.fromkeys(resolved_names)), is_alias_path_only

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
positional_args, subcommand_args, optional_args = [], [], []
fields = model.__pydantic_fields__ if is_pydantic_dataclass(model) else model.model_fields
for field_name, field_info in fields.items():
resolved_name = field_name if field_info.alias is None else field_info.alias
resolved_name = resolved_name.lower() if not self.case_sensitive else resolved_name
if _CliSubCommand in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
elif any((field_info.alias, field_info.validation_alias)):
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has an alias')
else:
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
if len(field_types) != 1:
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple types')
elif not is_model_class(field_types[0]):
raise SettingsError(
f'subcommand argument {model.__name__}.{resolved_name} is not derived from BaseModel'
f'subcommand argument {model.__name__}.{field_name} is not derived from BaseModel'
)
subcommand_args.append((field_name, resolved_name, field_info))
subcommand_args.append((field_name, field_info))
elif _CliPositionalArg in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
positional_args.append((field_name, resolved_name, field_info))
elif any((field_info.alias, field_info.validation_alias)):
raise SettingsError(f'positional argument {model.__name__}.{field_name} has an alias')
positional_args.append((field_name, field_info))
else:
optional_args.append((field_name, resolved_name, field_info))
optional_args.append((field_name, field_info))
return positional_args + subcommand_args + optional_args

@property
Expand Down Expand Up @@ -1251,6 +1288,7 @@ def _connect_root_parser(
arg_prefix=self.env_prefix,
subcommand_prefix=self.env_prefix,
group=None,
alias_prefixes=[],
)

def _add_parser_args(
Expand All @@ -1261,18 +1299,20 @@ def _add_parser_args(
arg_prefix: str,
subcommand_prefix: str,
group: Any,
alias_prefixes: list[str],
) -> ArgumentParser:
subparsers: Any = None
for field_name, resolved_name, field_info in self._sort_arg_fields(model):
alias_path_args: dict[str, str] = {}
for field_name, field_info in self._sort_arg_fields(model):
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
if _CliSubCommand in field_info.metadata:
if subparsers is None:
subparsers = self._add_subparsers(
parser, title='subcommands', dest=f'{arg_prefix}:subcommand', required=self.cli_enforce_required
)
self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{resolved_name}']
self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{field_name}']
else:
self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{resolved_name}')
self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{field_name}')
if hasattr(subparsers, 'metavar'):
metavar = ','.join(self._cli_subcommands[f'{arg_prefix}:subcommand'])
subparsers.metavar = f'{{{metavar}}}'
Expand All @@ -1281,23 +1321,25 @@ def _add_parser_args(
self._add_parser_args(
parser=self._add_parser(
subparsers,
resolved_name,
field_name,
help=field_info.description,
formatter_class=self._formatter_class,
description=model.__doc__,
),
model=model,
added_args=[],
arg_prefix=f'{arg_prefix}{resolved_name}.',
subcommand_prefix=f'{subcommand_prefix}{resolved_name}.',
arg_prefix=f'{arg_prefix}{field_name}.',
subcommand_prefix=f'{subcommand_prefix}{field_name}.',
group=None,
alias_prefixes=[],
)
else:
resolved_names, is_alias_path_only = self._get_resolved_names(field_name, field_info, alias_path_args)
arg_flag: str = '--'
kwargs: dict[str, Any] = {}
kwargs['default'] = SUPPRESS
kwargs['help'] = self._help_format(field_info)
kwargs['dest'] = f'{arg_prefix}{resolved_name}'
kwargs['dest'] = f'{arg_prefix}{resolved_names[0]}'
kwargs['metavar'] = self._metavar_format(field_info.annotation)
kwargs['required'] = self.cli_enforce_required and field_info.is_required()
if kwargs['dest'] in added_args:
Expand All @@ -1309,51 +1351,126 @@ def _add_parser_args(
if _annotation_contains_types(field_info.annotation, (dict, Mapping), is_strip_annotated=True):
self._cli_dict_args[kwargs['dest']] = field_info.annotation

arg_name = (
f'{arg_prefix}{resolved_name}'
if subcommand_prefix == self.env_prefix
else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{resolved_name}'
)
arg_names = self._get_arg_names(arg_prefix, subcommand_prefix, alias_prefixes, resolved_names)
if _CliPositionalArg in field_info.metadata:
kwargs['metavar'] = resolved_name.upper()
arg_name = kwargs['dest']
kwargs['metavar'] = resolved_names[0].upper()
arg_names = [kwargs['dest']]
del kwargs['dest']
del kwargs['required']
arg_flag = ''

if sub_models and kwargs.get('action') != 'append':
model_group: Any = None
model_group_kwargs: dict[str, Any] = {}
model_group_kwargs['title'] = f'{arg_name} options'
model_group_kwargs['description'] = (
sub_models[0].__doc__
if self.cli_use_class_docs_for_groups and len(sub_models) == 1
else field_info.description
self._add_parser_submodels(
parser,
sub_models,
added_args,
arg_prefix,
subcommand_prefix,
arg_flag,
arg_names,
kwargs,
field_info,
resolved_names,
)
if not self.cli_avoid_json:
added_args.append(arg_name)
kwargs['help'] = f'set {arg_name} from JSON string'
model_group = self._add_argument_group(parser, **model_group_kwargs)
self._add_argument(model_group, f'{arg_flag}{arg_name}', **kwargs)
for model in sub_models:
self._add_parser_args(
parser=parser,
model=model,
added_args=added_args,
arg_prefix=f'{arg_prefix}{resolved_name}.',
subcommand_prefix=subcommand_prefix,
group=model_group if model_group else model_group_kwargs,
)
elif is_alias_path_only:
continue
elif group is not None:
if isinstance(group, dict):
group = self._add_argument_group(parser, **group)
added_args.append(arg_name)
self._add_argument(group, f'{arg_flag}{arg_name}', **kwargs)
added_args += list(arg_names)
self._add_argument(group, *(f'{arg_flag}{name}' for name in arg_names), **kwargs)
else:
added_args.append(arg_name)
self._add_argument(parser, f'{arg_flag}{arg_name}', **kwargs)
added_args += list(arg_names)
self._add_argument(parser, *(f'{arg_flag}{name}' for name in arg_names), **kwargs)

self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
return parser

def _get_arg_names(
self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], resolved_names: tuple[str, ...]
) -> list[str]:
arg_names: list[str] = []
for prefix in [arg_prefix] + alias_prefixes:
for name in resolved_names:
arg_names.append(
f'{prefix}{name}'
if subcommand_prefix == self.env_prefix
else f'{prefix.replace(subcommand_prefix, "", 1)}{name}'
)
return arg_names

def _add_parser_submodels(
self,
parser: Any,
sub_models: list[type[BaseModel]],
added_args: list[str],
arg_prefix: str,
subcommand_prefix: str,
arg_flag: str,
arg_names: list[str],
kwargs: dict[str, Any],
field_info: FieldInfo,
resolved_names: tuple[str, ...],
) -> None:
model_group: Any = None
model_group_kwargs: dict[str, Any] = {}
model_group_kwargs['title'] = f'{arg_names[0]} options'
model_group_kwargs['description'] = (
sub_models[0].__doc__
if self.cli_use_class_docs_for_groups and len(sub_models) == 1
else field_info.description
)
if not self.cli_avoid_json:
added_args.append(arg_names[0])
kwargs['help'] = f'set {arg_names[0]} from JSON string'
model_group = self._add_argument_group(parser, **model_group_kwargs)
self._add_argument(model_group, *(f'{arg_flag}{name}' for name in arg_names), **kwargs)
for model in sub_models:
self._add_parser_args(
parser=parser,
model=model,
added_args=added_args,
arg_prefix=f'{arg_prefix}{resolved_names[0]}.',
subcommand_prefix=subcommand_prefix,
group=model_group if model_group else model_group_kwargs,
alias_prefixes=[f'{arg_prefix}{name}.' for name in resolved_names[1:]],
)

def _add_parser_alias_paths(
self,
parser: Any,
alias_path_args: dict[str, str],
added_args: list[str],
arg_prefix: str,
subcommand_prefix: str,
group: Any,
) -> None:
if alias_path_args:
context = parser
if group is not None:
context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group
is_nested_alias_path = arg_prefix.endswith('.')
arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix
for name, metavar in alias_path_args.items():
name = '' if is_nested_alias_path else name
arg_name = (
f'{arg_prefix}{name}'
if subcommand_prefix == self.env_prefix
else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{name}'
)
kwargs: dict[str, Any] = {}
kwargs['default'] = SUPPRESS
kwargs['help'] = 'pydantic alias path'
kwargs['dest'] = f'{arg_prefix}{name}'
if metavar == 'dict' or is_nested_alias_path:
kwargs['metavar'] = 'dict'
else:
kwargs['action'] = 'append'
kwargs['metavar'] = 'list'
if arg_name not in added_args:
added_args.append(arg_name)
self._add_argument(context, f'--{arg_name}', **kwargs)

def _get_modified_args(self, obj: Any) -> tuple[str, ...]:
if not self.cli_hide_none_type:
return get_args(obj)
Expand Down
Loading

0 comments on commit bd294a4

Please sign in to comment.