Skip to content

Commit

Permalink
- Fixed override of dict_kwargs items from command line not working.
Browse files Browse the repository at this point in the history
- Fixed multiple subclass init_args given through command line not being considered pytorch-lightning#15007.
  • Loading branch information
mauvilsa committed Oct 7, 2022
1 parent 9f0d081 commit 3337a0e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ Fixed
``init_args`` when loading from config.
- Subclass ``--*.help`` option not available when type is a ``Union`` mixed with
not subclass types.
- Override of ``dict_kwargs`` items from command line not working.
- Multiple subclass ``init_args`` given through command line not being
considered `pytorch-lightning#15007
<https://github.com/PyTorchLightning/pytorch-lightning/pull/15007>`__.


v4.15.0 (2022-09-27)
Expand Down
30 changes: 16 additions & 14 deletions jsonargparse/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,17 +327,10 @@ def __call__(self, *args, **kwargs):
if self.nargs == '?' and args[2] is None:
val = None
else:
parser, cfg, val, opt_str = args
cfg, val, opt_str = args[1:]
if isinstance(opt_str, str) and opt_str.startswith(f'--{self.dest}.'):
sub_opt = opt_str[len(f'--{self.dest}.'):]
val = NestedArg(key=sub_opt, val=val)
if self.dest not in cfg:
try:
default = parser.get_default(self.dest)
cfg = deepcopy(cfg)
cfg[self.dest] = default
except KeyError:
pass
append = opt_str == f'--{self.dest}+'
val = self._check_type(val, append=append, cfg=cfg)
args[1].update(val, self.dest)
Expand All @@ -356,9 +349,14 @@ def _check_type(self, value, append=False, cfg=None):
except get_loader_exceptions():
config_path = None
path_meta = val.pop('__path__', None) if isinstance(val, dict) else None

prev_val = cfg.get(self.dest) if cfg else None
if not prev_val and not sub_defaults.get() and is_subclass_spec(self.default):
prev_val = Namespace(class_path=self.default.class_path)

kwargs = {
'sub_add_kwargs': getattr(self, 'sub_add_kwargs', {}),
'prev_val': cfg.get(self.dest) if cfg else None,
'prev_val': prev_val,
'append': append,
}
try:
Expand Down Expand Up @@ -667,6 +665,8 @@ def is_subclass_spec(val):


def subclass_spec_as_namespace(val, prev_val=None):
if val is None:
return None
if isinstance(val, str):
return Namespace(class_path=val)
if isinstance(val, NestedArg):
Expand Down Expand Up @@ -767,7 +767,6 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
if isinstance(parser_or_action, ActionTypeHint):
sub_add_kwargs = getattr(parser_or_action, 'sub_add_kwargs', {})
parser = ActionTypeHint.get_class_parser(value['class_path'], sub_add_kwargs)
prev_val = subclass_spec_as_namespace(prev_val)
del_args = {}
for key, val in list(prev_val.init_args.__dict__.items()):
action = _find_action(parser, key)
Expand All @@ -787,6 +786,7 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):


def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev_val=None):
prev_val = subclass_spec_as_namespace(prev_val)
value = subclass_spec_as_namespace(value)
val_class = import_object(value.class_path)
parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs)
Expand Down Expand Up @@ -820,10 +820,12 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev
return value
return val_class(**{**init_args, **dict_kwargs})

prev_init_args = prev_val.get('init_args') if prev_val else None

if isinstance(init_args, NestedArg):
value['init_args'] = parser.parse_args(
[f'--{init_args.key}={init_args.val}'],
namespace=prev_val.init_args.clone(),
namespace=prev_init_args,
defaults=sub_defaults.get(),
)
return value
Expand All @@ -839,12 +841,12 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev
elif dict_kwargs:
init_args['dict_kwargs'] = dict_kwargs
dict_kwargs = None
init_args = parser.parse_object(init_args, defaults=sub_defaults.get())
init_args = parser.parse_object(init_args, cfg_base=prev_init_args, defaults=sub_defaults.get())
if init_args:
value['init_args'] = init_args
if dict_kwargs:
if isinstance(prev_val, Namespace) and prev_val.get('class_path') == value['class_path'] and prev_val.get('dict_kwargs'):
dict_kwargs.update(prev_val.get('dict_kwargs'))
if prev_val and prev_val.get('class_path') == value['class_path'] and prev_val.get('dict_kwargs'):
dict_kwargs = {**prev_val.get('dict_kwargs'), **dict_kwargs}
value['dict_kwargs'] = dict_kwargs
return value

Expand Down
17 changes: 16 additions & 1 deletion jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,21 @@ def test_list_append_config(self):
self.assertRaises(ParserError, lambda: parser.parse_args(['--cfg', 'val+: a']))


def test_list_append_subclass_init_args(self):
class Class:
def __init__(self, p1: int = 0, p2: int = 0):
pass

parser = ArgumentParser(error_handler=None)
parser.add_argument('--val', type=Union[Class, List[Class]])

with mock_module(Class) as module:
cfg = parser.parse_args([f'--val+={module}.Class', '--val.p1=1', '--val.p2=2', '--val.p1=3'])
self.assertEqual(cfg.val, [Namespace(class_path=f'{module}.Class', init_args=Namespace(p1=3, p2=2))])
cfg = parser.parse_args([f'--val+=Class', '--val.p2=2', '--val.p1=1'])
self.assertEqual(cfg.val, [Namespace(class_path=f'{module}.Class', init_args=Namespace(p1=1, p2=2))])


def test_list_append_subcommand_subclass(self):
class A:
def __init__(self, cals: Union[Calendar, List[Calendar]] = None):
Expand Down Expand Up @@ -792,7 +807,7 @@ def __init__(self, p1: int = 1, p2: str = '2', **kwargs):
self.assertIsInstance(cfg_init.cls, Class)
self.assertEqual(cfg_init.cls.kwargs, expected.dict_kwargs)

cfg = parser.parse_args(['--cls=Class', '--cls.dict_kwargs.p4=x', '--cls.dict_kwargs.p3=7.0'])
cfg = parser.parse_args(['--cls=Class', '--cls.dict_kwargs.p4=-', '--cls.dict_kwargs.p3=7.0', '--cls.dict_kwargs.p4=x'])
self.assertEqual(cfg.cls.dict_kwargs, expected.dict_kwargs)

with self.assertRaises(ParserError):
Expand Down

0 comments on commit 3337a0e

Please sign in to comment.