
Argument linking fails when setting model and data via the command line instead of passing them to the CLI constructor #16032

Closed
tobemo opened this issue Dec 13, 2022 · 3 comments · Fixed by omni-us/jsonargparse#218
Labels: 3rd party, bug, lightningcli


tobemo commented Dec 13, 2022

Bug description

Argument linking works when the model and datamodule classes are passed to the CLI constructor, like so: `MyLightningCLI(MyModel, MyData)`.

Argument linking fails when the model and datamodule are instead selected from the command line with flags such as `--model`, `--data` and/or `--config`. It raises `ValueError: Target key "model.foo" must be for an individual argument.`
See the code below.

How to reproduce the bug

from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import LightningModule, LightningDataModule


activations = {
    'MaxAbsScaler': 'Sigmoid',
    'Normalizer': 'Sigmoid',
    'QuantileTransformer': 'Sigmoid',
    'RobustScaler': 'PReLU',
    'StandardScaler': 'PReLU',
    'PowerTransformer': 'TanhShrink',
}


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments(
            "data.scaler", "model.scaler",
            compute_fn=lambda scaler: activations.get(scaler, 'linear'),
            apply_on='instantiate'
        )


class MyModel(LightningModule):
    def __init__(self, activation: str) -> None:
        super().__init__()
        self.save_hyperparameters()
        print(activation)


class MyData(LightningDataModule):
    def __init__(self, scaler: str) -> None:
        super().__init__()
        self.save_hyperparameters()
        print(scaler)


def main():
    # this works :)
    # cli = MyLightningCLI(MyModel, MyData)
    # > python .\debug.py fit --data.scaler MaxAbsScaler
    # > prints Sigmoid
    
    # this does not work :(
    cli = MyLightningCLI()
    # > python .\debug.py fit --data MyData --data.scaler MaxAbsScaler --model MyModel
    # > ValueError: Target key "model.scaler" must be for an individual argument.

if __name__ == '__main__':
    main()

Error messages and logs

Traceback (most recent call last):
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\debug.py", line 50, in <module>
    main()
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\debug.py", line 45, in main
    cli = MyLightningCLI()
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 343, in __init__
    self.setup_parser(run, main_kwargs, subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 403, in setup_parser
    self._add_subcommands(self.parser, **subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 480, in _add_subcommands
    subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **subparser_kwargs)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 486, in _prepare_subcommand_parser
    self._add_arguments(parser)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\pytorch_lightning\cli.py", line 439, in _add_arguments
    self.add_arguments_to_parser(parser)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\debug.py", line 17, in add_arguments_to_parser
    parser.link_arguments(
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\jsonargparse\link_arguments.py", line 379, in link_arguments
    ActionLink(self, source, target, compute_fn, apply_on)
  File "D:\Users\MC\Documents\ALGO_TRADING\0 BIAS\VAE\.venv\lib\site-packages\jsonargparse\link_arguments.py", line 141, in __init__
    raise ValueError(f'Target key "{target}" must be for an individual argument.')
ValueError: Target key "model.scaler" must be for an individual argument.

Environment

Current environment
* CUDA:
        - GPU:
                - NVIDIA GeForce GTX 970
        - available:         True
        - version:           11.7
* Lightning:
        - lightning-utilities: 0.3.0
        - pytorch-lightning: 1.8.3.post1
        - torch:             1.13.0+cu117
        - torchaudio:        0.13.0+cu117
        - torchmetrics:      0.11.0
        - torchvision:       0.14.0+cu117
* Packages:
        - absl-py:           1.3.0
        - aiohttp:           3.8.3
        - aiosignal:         1.3.1
        - alembic:           1.8.1
        - antlr4-python3-runtime: 4.9.3
        - asttokens:         2.1.0
        - async-timeout:     4.0.2
        - attrs:             22.1.0
        - autopage:          0.5.1
        - backcall:          0.2.0
        - cachetools:        5.2.0
        - certifi:           2022.9.24
        - charset-normalizer: 2.1.1
        - cliff:             4.1.0
        - cmaes:             0.9.0
        - cmd2:              2.4.2
        - colorama:          0.4.6
        - colorlog:          6.7.0
        - commonmark:        0.9.1
        - contourpy:         1.0.6
        - cycler:            0.11.0
        - debugpy:           1.6.3
        - decorator:         5.1.1
        - docstring-parser:  0.15
        - entrypoints:       0.4
        - executing:         1.2.0
        - fire:              0.4.0
        - fonttools:         4.38.0
        - frozenlist:        1.3.3
        - fsspec:            2022.11.0
        - google-auth:       2.15.0
        - google-auth-oauthlib: 0.4.6
        - greenlet:          2.0.1
        - grpcio:            1.51.1
        - hydra-core:        1.2.0
        - idna:              3.4
        - imageio:           2.22.4
        - importlib-metadata: 4.13.0
        - ipykernel:         6.17.1
        - ipython:           8.6.0
        - jedi:              0.18.1
        - joblib:            1.2.0
        - jsonargparse:      4.18.0
        - jupyter-client:    7.4.6
        - jupyter-core:      5.0.0
        - kiwisolver:        1.4.4
        - lightning-utilities: 0.3.0
        - mako:              1.2.4
        - markdown:          3.4.1
        - markupsafe:        2.1.1
        - matplotlib:        3.6.2
        - matplotlib-inline: 0.1.6
        - multidict:         6.0.2
        - nest-asyncio:      1.5.6
        - networkx:          2.8.8
        - numpy:             1.23.4
        - oauthlib:          3.2.2
        - omegaconf:         2.2.3
        - optuna:            3.0.4
        - packaging:         21.3
        - pandas:            1.5.1
        - parso:             0.8.3
        - pbr:               5.11.0
        - pickleshare:       0.7.5
        - pillow:            9.3.0
        - pip:               22.3.1
        - platformdirs:      2.5.4
        - prettytable:       3.5.0
        - prompt-toolkit:    3.0.32
        - protobuf:          3.20.1
        - psutil:            5.9.4
        - pure-eval:         0.2.2
        - pyarrow:           10.0.0
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pygments:          2.13.0
        - pyparsing:         3.0.9
        - pyperclip:         1.8.2
        - pyreadline3:       3.4.1
        - python-dateutil:   2.8.2
        - pytorch-lightning: 1.8.3.post1
        - pytz:              2022.6
        - pywavelets:        1.4.1
        - pywin32:           305
        - pyyaml:            6.0
        - pyzmq:             24.0.1
        - requests:          2.28.1
        - requests-oauthlib: 1.3.1
        - rich:              12.6.0
        - rsa:               4.9
        - scikit-image:      0.19.3
        - scikit-learn:      1.1.3
        - scipy:             1.8.1
        - setuptools:        58.1.0
        - six:               1.16.0
        - sqlalchemy:        1.4.44
        - stack-data:        0.6.1
        - stevedore:         4.1.1
        - tensorboard:       2.11.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - tensorboardx:      2.5.1
        - termcolor:         2.1.1
        - threadpoolctl:     3.1.0
        - tifffile:          2022.10.10
        - torch:             1.13.0+cu117
        - torchaudio:        0.13.0+cu117
        - torchmetrics:      0.11.0
        - torchvision:       0.14.0+cu117
        - tornado:           6.2
        - tqdm:              4.64.1
        - traitlets:         5.5.0
        - typing-extensions: 4.4.0
        - ujson:             5.5.0
        - urllib3:           1.26.12
        - wcwidth:           0.2.5
        - werkzeug:          2.2.2
        - wheel:             0.38.4
        - yarl:              1.8.1
        - zipp:              3.11.0
* System:
        - OS:                Windows
        - architecture:
                - 64bit
                - WindowsPE
        - processor:         AMD64 Family 23 Model 8 Stepping 2, AuthenticAMD
        - python:            3.10.4
        - version:           10.0.19045


cc @carmocca @mauvilsa

tobemo added the needs triage label Dec 13, 2022
@samvanstroud

See also omni-us/jsonargparse#208

awaelchli added the bug and lightningcli labels and removed the needs triage label Dec 17, 2022
carmocca added the 3rd party label Dec 17, 2022
@mauvilsa (Contributor)

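In pull request omni-us/jsonargparse#218 this error message was changed to make the problem clearer. Note that the `link_arguments` call in the code above is a mistake and should give an error: when the classes are chosen on the command line (subclass mode), the model's constructor arguments are nested under `model.init_args`, and `MyModel` takes `activation`, not `scaler`. The correct link would be: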

        parser.link_arguments(
            "data.scaler", "model.init_args.activation",
            compute_fn=lambda scaler: activations.get(scaler, 'linear'),
            apply_on='instantiate'
        )
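To see why `model.scaler` (or `model.activation`) is rejected while `model.init_args.activation` is accepted, it can help to picture the nested config the parser builds in each mode. The sketch below is a simplified, stdlib-only model of that structure, assuming that subclass mode stores the chosen class under `class_path` and its constructor arguments under `init_args`. The dictionaries, the `__main__.*` class paths and the `is_leaf_key` helper are hypothetical illustrations, not jsonargparse internals.

```python
from typing import Any, Dict

# Hypothetical, simplified picture of the parsed config in each mode.
# Mode 1: MyLightningCLI(MyModel, MyData) fixes the classes, so their
# __init__ parameters sit directly under "model" and "data".
fixed_mode: Dict[str, Any] = {
    "model": {"activation": None},
    "data": {"scaler": "MaxAbsScaler"},
}

# Mode 2: MyLightningCLI() with --model/--data picks classes at run time,
# so each group holds a class_path plus the constructor args in init_args.
subclass_mode: Dict[str, Any] = {
    "model": {"class_path": "__main__.MyModel", "init_args": {"activation": None}},
    "data": {"class_path": "__main__.MyData", "init_args": {"scaler": "MaxAbsScaler"}},
}


def is_leaf_key(config: Dict[str, Any], dotted_key: str) -> bool:
    """Return True if dotted_key names a single value (not missing, not a nested group)."""
    node: Any = config
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return False
        node = node[part]
    return not isinstance(node, dict)


for key in ("model.activation", "model.init_args.activation", "model.scaler"):
    print(key, is_leaf_key(fixed_mode, key), is_leaf_key(subclass_mode, key))
```

In this simplified picture, `model.activation` names a single value only when the classes are fixed in the constructor, `model.init_args.activation` only when they are chosen on the command line, and `model.scaler` in neither, since `MyModel` has no `scaler` parameter.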

@tobemo (Author)

tobemo commented Dec 19, 2022

Great, this works!
For completeness' sake, I would like to mention that, since `data.scaler` has already been instantiated by the time the link is applied (`apply_on='instantiate'`), it is no longer a string but an object.

This works:

parser.link_arguments(
    "data.scaler", "model.init_args.activation",
    compute_fn=lambda scaler: activations.get(type(scaler).__name__, 'linear'),
    apply_on='instantiate'
)
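The difference between the two `compute_fn` versions can be reproduced without Lightning. In this sketch `MaxAbsScaler` is a hypothetical stand-in class, assumed to play the role of the scaler object the real data module presumably stores as its `scaler` attribute. A dict keyed by class-name strings never matches the object itself, so the original lookup silently falls back to `'linear'` instead of raising an error.

```python
# Same mapping as in the issue, trimmed to two entries for brevity.
activations = {
    "MaxAbsScaler": "Sigmoid",
    "RobustScaler": "PReLU",
}


class MaxAbsScaler:
    """Hypothetical stand-in for the instantiated scaler object."""


scaler = MaxAbsScaler()

# Original compute_fn: looks up the object itself, which is not a key.
by_object = activations.get(scaler, "linear")

# Fixed compute_fn: looks up the class name, which matches a key.
by_type_name = activations.get(type(scaler).__name__, "linear")

print(by_object)     # linear
print(by_type_name)  # Sigmoid
```

Because the fallback value hides the mismatch, checking the printed activation (as the issue's `print(activation)` does) is a quick way to confirm the link resolves as intended.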

tobemo closed this as completed Dec 19, 2022