
manual_backward and .backward() have different behaviour. #18740

Open
roedoejet opened this issue Oct 6, 2023 · 4 comments
Labels
bug (Something isn't working) · repro needed (The issue is missing a reproducible example) · ver: 2.0.x

Comments


roedoejet commented Oct 6, 2023

Bug description

I expected self.manual_backward(loss) and loss.backward() to perform backpropagation in the same way, but when I use self.manual_backward a number of parameters end up with no gradient, as if unused. If I use .backward() directly, the problem doesn't occur.
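
For context, manual_backward is Lightning's replacement for loss.backward() under manual optimization; on a single device the two should behave the same, with Lightning additionally handling strategy details such as precision scaling. A minimal sketch (not the reporter's model) of the intended usage:

import torch
import torch.nn.functional as F
import lightning.pytorch as pl

class ManualModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # required before calling manual_backward
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        # expected to be equivalent to loss.backward() on a single device
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)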

What version are you seeing the problem on?

v2.0

How to reproduce the bug

# excerpt from the reporter's LightningModule; assumes `import torch` and
# `import torch.nn.functional as F` at module level
def training_step(self, batch, batch_idx):
        x, y, _, y_mel = batch
        y = y.unsqueeze(1)
        # x.size() & y_mel.size() = [batch_size, n_mels=80, n_frames=32]
        # y.size() = [batch_size, segment_size=8192]
        optim_g, optim_d = self.optimizers()
        scheduler_g, scheduler_d = self.lr_schedulers()
        # generate waveform
        if self.config.model.istft_layer:
            mag, phase = self(x)
            generated_wav = self.inverse_spectral_transform(
                mag * torch.exp(phase * 1j)
            ).unsqueeze(-2)
        else:
            generated_wav = self(x)
       
        # create mel
        generated_mel_spec = dynamic_range_compression_torch(
            self.spectral_transform(generated_wav).squeeze(1)[:, :, 1:]
        )
        # train discriminators
        optim_d.zero_grad()
        # MPD
        y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, generated_wav.detach())
        if self.use_gradient_penalty:
            gp_f = self.compute_gradient_penalty(y.data, generated_wav.detach().data, self.mpd)
        else:
            gp_f = None
        loss_disc_f, _, _ = self.discriminator_loss(y_df_hat_r, y_df_hat_g, gp=gp_f)
        self.log("training/disc/mpd_loss", loss_disc_f, prog_bar=False)
        # MSD
        y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, generated_wav.detach())
        if self.use_gradient_penalty:
            # gradient penalty for the MSD, assumed analogous to gp_f above
            gp_s = self.compute_gradient_penalty(y.data, generated_wav.detach().data, self.msd)
        else:
            gp_s = None
        loss_disc_s, _, _ = self.discriminator_loss(y_ds_hat_r, y_ds_hat_g, gp=gp_s)
        self.log("training/disc/msd_loss", loss_disc_s, prog_bar=False)
        # calculate loss
        disc_loss_total = loss_disc_s + loss_disc_f
        # manual optimization because PyTorch Lightning 2.0+ doesn't handle automatic optimization for multiple optimizers
        # this works
        disc_loss_total.backward()
        # this does not
        # self.manual_backward(disc_loss_total)
        optim_d.step()
        scheduler_d.step()
        # log discriminator loss
        self.log("training/disc/d_loss_total", disc_loss_total, prog_bar=False)
            
        # train generator
        optim_g.zero_grad()
        # calculate loss
        _, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, generated_wav)
        _, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, generated_wav)
        loss_fm_f = self.feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = self.feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, _ = self.generator_loss(
            y_df_hat_g, gp=self.use_gradient_penalty
        )
        loss_gen_s, _ = self.generator_loss(
            y_ds_hat_g, gp=self.use_gradient_penalty
        )
        self.log("training/gen/loss_fmap_f", loss_fm_f, prog_bar=False)
        self.log("training/gen/loss_fmap_s", loss_fm_s, prog_bar=False)
        self.log("training/gen/loss_gen_f", loss_gen_f, prog_bar=False)
        self.log("training/gen/loss_gen_s", loss_gen_s, prog_bar=False)
        loss_mel = F.l1_loss(y_mel, generated_mel_spec) * 45
        gen_loss_total = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
        # manual optimization because PyTorch Lightning 2.0+ doesn't handle automatic optimization for multiple optimizers
        gen_loss_total.backward()
        optim_g.step()
        scheduler_g.step()
        # log generator loss
        self.log("training/gen/gen_loss_total", gen_loss_total, prog_bar=True)
        self.log("training/gen/mel_spec_error", loss_mel / 45, prog_bar=False)

I caught this by adding an on_after_backward method. When I use self.manual_backward(disc_loss_total) or self.manual_backward(gen_loss_total) I get a number of parameters with p.grad is None, but when I use disc_loss_total.backward() everything works fine.
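
For reference, a sketch of that kind of diagnostic hook (the message format is illustrative). Note that Lightning only invokes on_after_backward for backward passes it runs itself, so after a raw .backward() a check like this has to be called explicitly:

    def on_after_backward(self):
        # list trainable parameters that received no gradient in this backward pass
        missing = [
            name for name, p in self.named_parameters()
            if p.requires_grad and p.grad is None
        ]
        if missing:
            self.print(f"{len(missing)} parameters without gradients, e.g. {missing[:5]}")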

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    • GPU:
      • Tesla V100-SXM2-16GB
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.0.4
    • lightning-cloud: 0.5.39
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.0.9.post0
    • torch: 2.0.1+cu117
    • torchaudio: 2.0.2+cu117
    • torchmetrics: 1.2.0
  • Packages:
    • absl-py: 2.0.0
    • aiohttp: 3.8.5
    • aiosignal: 1.3.1
    • aniso8601: 9.0.1
    • annotated-types: 0.5.0
    • anyio: 3.7.1
    • anytree: 2.9.0
    • arrow: 1.3.0
    • async-timeout: 4.0.3
    • attrs: 23.1.0
    • audioread: 3.0.1
    • beautifulsoup4: 4.12.2
    • bidict: 0.22.1
    • black: 22.12.0
    • blessed: 1.20.0
    • cachetools: 5.3.1
    • certifi: 2023.7.22
    • cffi: 1.16.0
    • cfgv: 3.4.0
    • charset-normalizer: 3.3.0
    • click: 8.1.7
    • clipdetect: 0.1.3
    • cmake: 3.27.6
    • colorama: 0.4.6
    • coloredlogs: 14.0
    • contourpy: 1.1.1
    • croniter: 1.3.15
    • cycler: 0.12.0
    • cython: 3.0.3
    • dateutils: 0.6.12
    • decorator: 5.1.1
    • deepdiff: 6.6.0
    • distlib: 0.3.7
    • dnspython: 2.3.0
    • editdistance: 0.6.2
    • einops: 0.5.0
    • et-xmlfile: 1.1.0
    • eventlet: 0.33.3
    • everyvoice: 0.1.20231005
    • exceptiongroup: 1.1.3
    • fastapi: 0.103.2
    • filelock: 3.12.4
    • flake8: 6.1.0
    • flask: 2.2.5
    • flask-cors: 4.0.0
    • flask-restful: 0.3.10
    • flask-socketio: 5.3.6
    • flask-talisman: 1.1.0
    • fonttools: 4.43.0
    • frozenlist: 1.4.0
    • fsspec: 2023.9.2
    • g2p: 1.1.20230822
    • gitlint-core: 0.19.1
    • google-auth: 2.23.2
    • google-auth-oauthlib: 1.0.0
    • greenlet: 3.0.0
    • grpcio: 1.59.0
    • h11: 0.14.0
    • humanfriendly: 10.0
    • identify: 2.5.30
    • idna: 3.4
    • importlib-metadata: 6.8.0
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • isort: 5.12.0
    • itsdangerous: 2.1.2
    • jinja2: 3.1.2
    • joblib: 1.3.2
    • jsonschema: 4.19.1
    • jsonschema-specifications: 2023.7.1
    • kiwisolver: 1.4.5
    • librosa: 0.9.2
    • lightning: 2.0.4
    • lightning-cloud: 0.5.39
    • lightning-utilities: 0.9.0
    • lit: 17.0.2
    • llvmlite: 0.41.0
    • loguru: 0.6.0
    • markdown: 3.4.4
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.6.0
    • mccabe: 0.7.0
    • mdurl: 0.1.2
    • merge-args: 0.1.5
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • munkres: 1.1.4
    • mypy: 1.5.1
    • mypy-extensions: 1.0.0
    • networkx: 2.8.4
    • nltk: 3.7
    • nodeenv: 1.8.0
    • numba: 0.58.0
    • numpy: 1.25.2
    • oauthlib: 3.2.2
    • openpyxl: 3.1.2
    • ordered-set: 4.1.0
    • packaging: 23.2
    • pandas: 1.4.4
    • panphon: 0.20.0
    • pathspec: 0.11.2
    • pillow: 10.0.1
    • pip: 23.2.1
    • platformdirs: 3.11.0
    • pluggy: 1.3.0
    • pooch: 1.7.0
    • pre-commit: 3.4.0
    • prompt-toolkit: 3.0.39
    • protobuf: 4.24.4
    • psutil: 5.9.5
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycodestyle: 2.11.0
    • pycountry: 22.3.5
    • pycparser: 2.21
    • pydantic: 2.4.2
    • pydantic-core: 2.10.1
    • pyflakes: 3.1.0
    • pygments: 2.16.1
    • pyjwt: 2.8.0
    • pympi-ling: 1.70.2
    • pyparsing: 3.1.1
    • pysdtw: 0.0.5
    • pytest: 7.4.2
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-engineio: 4.7.1
    • python-multipart: 0.0.6
    • python-socketio: 5.9.0
    • pytorch-lightning: 2.0.9.post0
    • pytz: 2023.3.post1
    • pyworld: 0.3.4
    • pyyaml: 6.0.1
    • questionary: 1.10.0
    • readchar: 4.0.5
    • referencing: 0.30.2
    • regex: 2023.10.3
    • requests: 2.31.0
    • requests-oauthlib: 1.3.1
    • resampy: 0.4.2
    • rich: 13.6.0
    • rpds-py: 0.10.4
    • rsa: 4.9
    • scikit-learn: 1.3.1
    • scipy: 1.11.3
    • setuptools: 59.5.0
    • sh: 2.0.6
    • shellingham: 1.5.3
    • simple-term-menu: 1.5.2
    • simple-websocket: 1.0.0
    • six: 1.16.0
    • sniffio: 1.3.0
    • soundfile: 0.12.1
    • soupsieve: 2.5
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.12
    • tabulate: 0.8.10
    • tensorboard: 2.14.1
    • tensorboard-data-server: 0.7.1
    • text-unidecode: 1.3
    • threadpoolctl: 3.2.0
    • tomli: 2.0.1
    • torch: 2.0.1+cu117
    • torchaudio: 2.0.2+cu117
    • torchmetrics: 1.2.0
    • tqdm: 4.66.1
    • traitlets: 5.11.2
    • triton: 2.0.0
    • typer: 0.9.0
    • types-python-dateutil: 2.8.19.14
    • types-pyyaml: 6.0.12.12
    • types-requests: 2.31.0.8
    • types-setuptools: 68.2.0.0
    • types-tabulate: 0.8.11
    • typing-extensions: 4.8.0
    • unicodecsv: 0.14.1
    • urllib3: 2.0.6
    • uvicorn: 0.23.2
    • virtualenv: 20.24.5
    • wcwidth: 0.2.8
    • websocket-client: 1.6.3
    • websockets: 11.0.3
    • werkzeug: 2.2.3
    • wheel: 0.41.2
    • wsproto: 1.2.0
    • yarl: 1.9.2
    • zipp: 3.17.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.9.18
    • release: 4.15.0-204-generic
    • version: #215-Ubuntu SMP Fri Jan 20 18:24:59 UTC 2023

More info

No response

@roedoejet roedoejet added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Oct 6, 2023
@awaelchli awaelchli added repro needed The issue is missing a reproducible example and removed needs triage Waiting to be triaged by maintainers labels Oct 6, 2023
@awaelchli
Contributor

awaelchli commented Oct 6, 2023

Hey @roedoejet
Is this reproducible in our GAN examples?
https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/domain_templates/generative_adversarial_net.py

@roedoejet
Author

roedoejet commented Oct 6, 2023

I'm away now until next week, but I will give it a shot then and post an update here. Thanks.

@roedoejet
Author

Unfortunately, this is not reproducible with the GAN example linked above in my environment. I will poke around a bit more and try to find a minimal example.
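
In case it helps, a hypothetical minimal harness along those lines, running one step each way and printing which parameters received gradients (the model, data, and use_manual_backward flag are placeholders, not code from this issue):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl

class Repro(pl.LightningModule):
    def __init__(self, use_manual_backward: bool):
        super().__init__()
        self.automatic_optimization = False
        self.use_manual_backward = use_manual_backward
        self.net = torch.nn.Sequential(
            torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
        )

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        x, y = batch
        loss = F.mse_loss(self.net(x), y)
        if self.use_manual_backward:
            self.manual_backward(loss)
        else:
            loss.backward()
        # report which parameters received gradients under each backward variant
        print({n: p.grad is not None for n, p in self.named_parameters()})
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(16, 8), torch.randn(16, 1)), batch_size=4)
    for flag in (True, False):
        trainer = pl.Trainer(max_steps=1, logger=False, enable_checkpointing=False)
        trainer.fit(Repro(flag), data)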

@awaelchli
Contributor

@roedoejet Thanks for looking at it. Due to priorities, I won't have the bandwidth to search for the bug. If you find a way to reproduce this in a code example we can study, that would be invaluable.
