
discard_init_args_on_class_path_change does not handle nested dict subclass specs in some contexts #247

Closed
speediedan opened this issue Feb 23, 2023 · 0 comments · Fixed by #248
Labels
bug Something isn't working


@speediedan
Contributor

🐛 Bug report

Firstly, thanks again for all your work maintaining this remarkably valuable package!

A subclass spec can be either a dict or a Namespace:

def is_subclass_spec(val):
    is_class = isinstance(val, (dict, Namespace)) and 'class_path' in val
    if is_class:
        keys = getattr(val, '__dict__', val).keys()
        is_class = len(set(keys)-{'class_path', 'init_args', 'dict_kwargs', '__path__'}) == 0
    return is_class

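When ActionTypeHint.discard_init_args_on_class_path_change is entered, a nested spec that is still a plain dict passes the is_subclass_spec check and is handed as-is to the module-level discard_init_args_on_class_path_change: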
@staticmethod
def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg):
    if isinstance(prev_cfg, dict):
        return
    keys = list(prev_cfg.keys(branches=True))
    num = 0
    while num < len(keys):
        key = keys[num]
        prev_val = prev_cfg.get(key)
        val = cfg.get(key)
        if is_subclass_spec(prev_val) and is_subclass_spec(val):
            action = parser_or_action
            if not isinstance(parser_or_action, ActionTypeHint):
                action = _find_action(parser_or_action, key)
            if isinstance(action, ActionTypeHint):
                discard_init_args_on_class_path_change(action, prev_val, val)
        # ... (remainder of the loop omitted in this excerpt)
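To make the dispatch concrete, here is a minimal sketch that re-creates the is_subclass_spec check quoted above and applies it to the nested some_custom_init values from the repro below. It assumes argparse.Namespace as a stand-in for jsonargparse.Namespace (both support "in" for key lookup and keep their entries in __dict__). A plain dict passes the check just like a Namespace, so nothing at this point converts it.

```python
from argparse import Namespace  # assumption: stand-in for jsonargparse.Namespace


def is_subclass_spec(val):
    """Re-creation of the quoted check: a dict or Namespace that has a
    class_path and no keys other than the recognised subclass-spec keys."""
    is_class = isinstance(val, (dict, Namespace)) and 'class_path' in val
    if is_class:
        # A Namespace keeps its entries in __dict__; a dict is its own mapping
        keys = getattr(val, '__dict__', val).keys()
        is_class = len(set(keys) - {'class_path', 'init_args', 'dict_kwargs', '__path__'}) == 0
    return is_class


# Nested values from the repro: previous (default) and current (override) specs
prev_val = {
    'class_path': 'torch.optim.lr_scheduler.CosineAnnealingWarmRestarts',
    'init_args': {'T_0': 1, 'T_mult': 2, 'eta_min': 1.0e-07},
}
val = {
    'class_path': 'torch.optim.lr_scheduler.StepLR',
    'init_args': {'step_size': 1, 'gamma': 0.7},
}

# Plain dicts qualify, so the caller forwards them without conversion
print(is_subclass_spec(prev_val), is_subclass_spec(val))
# A Namespace with the same content qualifies as well
print(is_subclass_spec(Namespace(**val)))
# An unrecognised extra key disqualifies a mapping
print(is_subclass_spec({**val, 'extra': 1}))
```

The key-set comparison is what keeps arbitrary dicts that merely contain a class_path entry from being treated as subclass specs.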

However, jsonargparse.typehints.discard_init_args_on_class_path_change, shown below, currently assumes that any previous value for a nested subclass has already been cast to a Namespace:

def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
    if prev_val and 'init_args' in prev_val and prev_val['class_path'] != value['class_path']:
        parser = parser_or_action
        if isinstance(parser_or_action, ActionTypeHint):
            sub_add_kwargs = getattr(parser_or_action, 'sub_add_kwargs', {})
            parser = ActionTypeHint.get_class_parser(value['class_path'], sub_add_kwargs)
        del_args = {}
        for key, val in list(prev_val.init_args.__dict__.items()):
            action = _find_action(parser, key)
            # ... (remainder omitted in this excerpt)

The attribute access prev_val.init_args at line 941 of jsonargparse/typehints.py (the for loop above) then fails with:

argparse.ArgumentError: Parser key "some_class":
  Problem with given class_path '__main__.CustomModule':
    - 'dict' object has no attribute 'init_args'
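The failure can be reproduced without jsonargparse. The guard at the top of the module-level function uses subscript access ('init_args' in prev_val, prev_val['class_path']), which works for a dict, while the loop header uses attribute access (prev_val.init_args), which does not. A minimal sketch, assuming argparse.Namespace as a stand-in for jsonargparse.Namespace; guard_passes and init_arg_items are hypothetical helpers that mirror those two steps.

```python
from argparse import Namespace  # assumption: stand-in for jsonargparse.Namespace


def guard_passes(prev_val, value):
    """Hypothetical mirror of the subscript-based guard (exercised with dicts)."""
    return bool(prev_val) and 'init_args' in prev_val and prev_val['class_path'] != value['class_path']


def init_arg_items(prev_val):
    """Hypothetical mirror of the attribute-based loop header."""
    return list(prev_val.init_args.__dict__.items())


prev_val = {'class_path': 'torch.optim.lr_scheduler.CosineAnnealingWarmRestarts',
            'init_args': {'T_0': 1, 'T_mult': 2, 'eta_min': 1.0e-07}}
value = {'class_path': 'torch.optim.lr_scheduler.StepLR',
         'init_args': {'step_size': 1, 'gamma': 0.7}}

# The class_path changed, so the dict passes the guard
print(guard_passes(prev_val, value))

try:
    init_arg_items(prev_val)
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'init_args'

# The same content as nested Namespaces supports the attribute access
ns = Namespace(class_path=prev_val['class_path'], init_args=Namespace(**prev_val['init_args']))
print(init_arg_items(ns))
```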

Currently, nested generic dict configurations, such as the repro case below, do not go through the subclass_spec_as_namespace Namespace wrapping that jsonargparse.typehints.discard_init_args_on_class_path_change relies on when discarding overridden subclass spec configuration in this context.

With the minor patch below to jsonargparse.typehints.discard_init_args_on_class_path_change, this use case can be accommodated. (I haven't had a chance to run the full jsonargparse test suite, but I can't immediately think of any problem with this additional check and transformation.)

def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
    if prev_val and 'init_args' in prev_val and prev_val['class_path'] != value['class_path']:
        parser = parser_or_action
        if isinstance(parser_or_action, ActionTypeHint):
            sub_add_kwargs = getattr(parser_or_action, 'sub_add_kwargs', {})
            parser = ActionTypeHint.get_class_parser(value['class_path'], sub_add_kwargs)
        del_args = {}
        # Nested specs (e.g. under a Dict[str, Any] parameter) may still be plain dicts
        if isinstance(prev_val, dict):
            prev_val = subclass_spec_as_namespace(prev_val)
        for key, val in list(prev_val.init_args.__dict__.items()):
            action = _find_action(parser, key)
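The effect of the added isinstance(prev_val, dict) check can be sketched in isolation. The helper spec_as_namespace below is hypothetical, a simplified stand-in for jsonargparse's subclass_spec_as_namespace rather than its actual implementation, and argparse.Namespace again replaces jsonargparse.Namespace. The set of parameter names for the new class is likewise an assumption for illustration; in jsonargparse the lookup goes through _find_action on the new class's parser.

```python
from argparse import Namespace  # assumption: stand-in for jsonargparse.Namespace


def spec_as_namespace(spec):
    """Hypothetical, simplified stand-in for subclass_spec_as_namespace:
    wrap a dict spec, and its init_args dict, in new Namespace objects."""
    ns = Namespace(**spec)
    if isinstance(getattr(ns, 'init_args', None), dict):
        ns.init_args = Namespace(**ns.init_args)
    return ns


def discarded_init_args(prev_val, new_params):
    """Sketch of the patched step: normalise a dict spec, then collect the
    previous init_args that the new class does not accept."""
    if isinstance(prev_val, dict):  # the check added by the proposed patch
        prev_val = spec_as_namespace(prev_val)
    return {key: val for key, val in prev_val.init_args.__dict__.items()
            if key not in new_params}


prev_val = {'class_path': 'torch.optim.lr_scheduler.CosineAnnealingWarmRestarts',
            'init_args': {'T_0': 1, 'T_mult': 2, 'eta_min': 1.0e-07}}
# Assumed parameter names of the new class, for illustration only
step_lr_params = {'step_size', 'gamma'}

print(discarded_init_args(prev_val, step_lr_params))
```

Here all three CosineAnnealingWarmRestarts arguments are reported as discardable, which is consistent with the expected result in this issue, where some_custom_init keeps only the StepLR init_args. Because the conversion builds new Namespace objects, the caller's dict is left untouched.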

To reproduce

from typing import Any, Dict

import yaml
import lightning.pytorch as pl
from jsonargparse import ArgumentParser


override_config = yaml.safe_dump(
    {
        "class_path": "CustomModule",
        "init_args": {
            "some_custom_init": {
                "class_path": "torch.optim.lr_scheduler.StepLR",
                "init_args": {"step_size": 1, "gamma": 0.7},
            }
        },
    }
)

default_config = yaml.safe_dump(
    {
        "class_path": "CustomModule",
        "init_args": {
            "some_custom_init": {
                "class_path": "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts",
                "init_args": {"T_0": 1, "T_mult": 2, "eta_min": 1.0e-07},
            }
        },
    }
)


class CustomModule(pl.LightningModule):
    # A generic Dict[str, Any] hint leaves the nested subclass spec as a plain dict
    def __init__(self, some_custom_init: Dict[str, Any]):
        super().__init__()


parser = ArgumentParser()
parser.add_argument("--some_class", type=pl.LightningModule)
# Parse the default config first, then override it with a different nested class_path
result = parser.parse_args(["--some_class", default_config, "--some_class", override_config])
print(f"expected_result= {result}")

Current behavior:

argparse.ArgumentError: Parser key "some_class":
  Problem with given class_path '__main__.CustomModule':
    - 'dict' object has no attribute 'init_args'

Expected result, generated with the patch above applied:

expected_result= Namespace(some_class=Namespace(class_path='__main__.CustomModule', init_args=Namespace(some_custom_init={'class_path': 'torch.optim.lr_scheduler.StepLR', 'init_args': {'gamma': 0.7, 'step_size': 1}})))

Environment

  • jsonargparse version (e.g., 4.8.0): 4.20.0
  • Python version (e.g., 3.9): 3.10
  • How jsonargparse was installed (e.g. pip install jsonargparse[all]): pip install "jsonargparse[signatures]==4.20.0"
  • OS (e.g., Linux): Ubuntu 20.04