This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Flash DeepSpeedPlugin error #1249

Closed
Moumeneb1 opened this issue Mar 25, 2022 · 9 comments · Fixed by #1377

Comments

@Moumeneb1

🐛 Bug

Hi, I'm completely new to Flash. I'm getting this error when trying to use the DeepSpeedPlugin with the Flash trainer.
The code to reproduce is below.

To Reproduce

import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), plugins="deepspeed")
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")

I'm getting this error:
[Screenshot of the error traceback, 2022-03-25 16:45]

I think there is an issue with named_params somewhere.

@Moumeneb1 added the "bug / fix" and "help wanted" labels on Mar 25, 2022
@ethanwharris
Collaborator

Hey @Moumeneb1, thanks for reporting this! We'll look into it and get back to you 😃

@ar90n
Contributor

ar90n commented May 13, 2022

Is anyone looking into this? If not, may I try to look into it?

@krshrimali
Contributor

> Is anyone looking into this? If not, may I try to look into it?

Hi, @ar90n - Thank you for showing interest in this issue! It would be great if you could take a look; as far as I'm aware, no one is actively working on it, so please feel free to pick it up. In case you face any issues, feel free to reach out to me or @ethanwharris on our PyTorch Lightning Slack workspace.

I've also assigned this issue to you! :)

@ar90n
Contributor

ar90n commented May 13, 2022

@krshrimali
Thanks for your kindness. I'm going to get to work on this issue.

@ar90n
Contributor

ar90n commented May 16, 2022

Hi, after looking into this issue, I found its cause. It doesn't depend on lightning-flash; it occurs with pytorch-lightning alone. There are some similar issues on the pytorch-lightning GitHub project page, so I added comments to them and created a new issue about the integration between pytorch-lightning and DeepSpeed. They are the following.

I think that if the above issues are solved, this issue will be solved automatically. I will continue trying to solve them.

@ar90n
Contributor

ar90n commented May 22, 2022

Hi, after my investigation, I found that finetuning with pytorch-lightning does not work correctly with DeepSpeed. Concretely, loading the parameters doesn't work when resuming training.
As I commented in this issue, this is caused by pytorch-lightning's lifecycle, that is, the order in which callbacks are called. So I think it is difficult to change because of the risk to compatibility with older versions. But if we can omit support for resuming training, there is a simple workaround like the following. It doesn't support resuming, but it works.

import torch
from functools import partial

import flash
import pytorch_lightning as pl
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

from flash.core.finetuning import FlashBaseFinetuning, FinetuningStrategies, _FINETUNING_STRATEGIES_REGISTRY


class FlashDeepSpeedFinetuning(FlashBaseFinetuning):
    # Run the finetuning function directly on each active optimizer at every
    # epoch start, skipping BaseFinetuning's parameter-group bookkeeping,
    # which fails once DeepSpeed wraps the optimizers.
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        from pytorch_lightning.loops.utilities import _get_active_optimizers

        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, 0):
            self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)


class NoFreeze(FlashDeepSpeedFinetuning):
    def __init__(self, train_bn: bool = True):
        super().__init__(FinetuningStrategies.NO_FREEZE, train_bn)


# Register the DeepSpeed-compatible callback under a new strategy key.
_FINETUNING_STRATEGIES_REGISTRY(
    name="no_freeze_deepspeed",
    fn=partial(FlashDeepSpeedFinetuning, strategy_key=FinetuningStrategies.NO_FREEZE),
)


# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count(), strategy="deepspeed")
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze_deepspeed")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
torch.save(model.state_dict(), "./speech_recognition_model.pth")
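
Since this workaround saves a plain state_dict rather than a Lightning checkpoint, reloading the model later could look roughly like this (a minimal sketch, assuming the task is rebuilt with the same backbone and the path used above):

import torch

from flash.audio import SpeechRecognition

# Rebuild the task with the same backbone, then load the raw state_dict
# saved by the workaround script (path assumed from above).
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
model.load_state_dict(torch.load("./speech_recognition_model.pth"))
model.eval()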

What do you think about this workaround?

@krshrimali
Contributor

Hi, @ar90n - I just wanted to thank you for your continuous efforts on this issue! Apologies for not being able to respond on time.

Just to acknowledge that we are aware of this, I'll get back to you in a day or two. :) I appreciate the patience 🚀 and your efforts! ❤️

@ethanwharris
Collaborator

This is awesome @ar90n!!! Thanks for digging into this. We would definitely welcome a PR to add your no_freeze_deepspeed strategy to Flash if that's something you're interested in. Also, do you know if it would be possible to make a freeze strategy work with DeepSpeed too? That would unlock finetuning with DeepSpeed, which would be quite cool 😃

@ethanwharris added this to the 0.9.0 milestone on Jun 29, 2022
@ar90n
Contributor

ar90n commented Jul 2, 2022

@ethanwharris
Thanks for your suggestion! I'm very interested in creating a PR for the no_freeze_deepspeed strategy. The freeze strategy doesn't work for the same reason as no_freeze, so it can be solved in the same way. I will create a PR adding the no_freeze_deepspeed and freeze_deepspeed strategies.
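
For reference, the freeze variant could mirror the registration used in the workaround above. This is only a minimal sketch, assuming FinetuningStrategies.FREEZE exists alongside NO_FREEZE and reusing the FlashDeepSpeedFinetuning class defined earlier:

from functools import partial

from flash.core.finetuning import FinetuningStrategies, _FINETUNING_STRATEGIES_REGISTRY

# Bind the DeepSpeed-compatible finetuning callback to the FREEZE strategy as
# well (sketch only; FlashDeepSpeedFinetuning comes from the earlier comment).
_FINETUNING_STRATEGIES_REGISTRY(
    name="freeze_deepspeed",
    fn=partial(FlashDeepSpeedFinetuning, strategy_key=FinetuningStrategies.FREEZE),
)

# Finetuning would then be invoked with the new strategy key, e.g.:
# trainer.finetune(model, datamodule=datamodule, strategy="freeze_deepspeed")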
