
Add NeMo-scope optimizer support and add the Novograd optimizer #793

Merged: 15 commits merged from candidate_optimizer_refactor into NVIDIA:candidate on Jul 2, 2020

Conversation

@titu1994 (Collaborator) commented on Jul 1, 2020

Salient points

Refactors optimizer setup so the optimizer can be changed from the command line. Run the codebase using the examples below.

  1. Default to Adam if no optimizer is provided via --optimizer.
  2. Create the base optimizer if no overriding --opt_args are passed.
  3. Unify optimizer support across all NeMo domains, with a standardized interface for adding optimizer, lr, and opt_args to argparse.

Usage

Use Adam and just override LR

python speech_to_text.py \
        --asr_model "bad_quartznet15x5.yaml" \
        --train_dataset "./an4/train_manifest.json" \
        --eval_dataset "./an4/test_manifest.json" \
        --gpus 4 \
        --distributed_backend "ddp" \
        --max_epochs 1 \
        --fast_dev_run \
        --lr 0.001 

Change optimizer and override LR

python speech_to_text.py \
        --asr_model "bad_quartznet15x5.yaml" \
        --train_dataset "./an4/train_manifest.json" \
        --eval_dataset "./an4/test_manifest.json" \
        --gpus 4 \
        --distributed_backend "ddp" \
        --max_epochs 1 \
        --fast_dev_run \
        --optimizer novograd \
        --lr 0.01 

Change optimizer, override LR, override optimizer args

python speech_to_text.py \
        --asr_model "bad_quartznet15x5.yaml" \
        --train_dataset "./an4/train_manifest.json" \
        --eval_dataset "./an4/test_manifest.json" \
        --gpus 4 \
        --distributed_backend "ddp" \
        --max_epochs 1 \
        --fast_dev_run \
        --optimizer novograd \
        --lr 0.01 \
        --opt_args betas=0.95,0.5 weight_decay=1e-3
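For reference, the key=value strings given to --opt_args can be turned into optimizer keyword arguments with a small helper along the lines of the sketch below. The helper name parse_opt_args and its exact behaviour are illustrative assumptions, not necessarily the implementation in this PR.

def parse_opt_args(opt_args):
    # Turn ['betas=0.95,0.5', 'weight_decay=1e-3'] into
    # {'betas': (0.95, 0.5), 'weight_decay': 0.001}.
    kwargs = {}
    for arg in opt_args or []:
        key, value = arg.split('=', 1)
        parts = [float(v) for v in value.split(',')]
        # Scalars stay scalars; comma-separated values become tuples (e.g. betas).
        kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
    return kwargs

print(parse_opt_args(['betas=0.95,0.5', 'weight_decay=1e-3']))
# -> {'betas': (0.95, 0.5), 'weight_decay': 0.001}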

Usage - Overriding default args in the model itself

When calling add_optimizer_args(parser), arguments can be passed directly to override the argparser's default values, as shown below. With the same API, different domains can therefore use different optimizers with different arguments.

Base initialization

parser = add_optimizer_args(parser)  # Use adam and empty opt_args list by default

Override optimizer

parser = add_optimizer_args(parser, optimizer='novograd')  # Use novograd and empty opt_args list by default

Override optimizer and args

novograd_args = {'betas': (0.95, 0.5), 'weight_decay': 0.001}
parser = add_optimizer_args(parser, optimizer='novograd', default_opt_args=novograd_args)  # Use novograd and custom defaults
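For illustration only, a stand-in for add_optimizer_args could register the three flags used throughout this PR roughly as sketched below. The actual helper lives in the optimizer module touched by this PR and may differ in its defaults and validation.

from argparse import ArgumentParser

def add_optimizer_args(parser, optimizer='adam', default_lr=None, default_opt_args=None):
    # Illustrative stand-in, not the PR's real implementation.
    parser.add_argument('--optimizer', type=str, default=optimizer,
                        help='Name of the optimizer, e.g. adam or novograd')
    parser.add_argument('--lr', type=float, default=default_lr, help='Learning rate')
    # On the command line, --opt_args takes key=value strings; a programmatic
    # default (such as the novograd_args dict above) can also be supplied.
    parser.add_argument('--opt_args', nargs='*', default=default_opt_args,
                        help='Optimizer arguments, e.g. betas=0.95,0.5 weight_decay=1e-3')
    return parser

parser = add_optimizer_args(ArgumentParser(), optimizer='novograd',
                            default_opt_args={'betas': (0.95, 0.5), 'weight_decay': 0.001})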

Signed-off-by: smajumdar [email protected]

@titu1994 marked this pull request as draft on July 1, 2020 02:03
@titu1994 requested a review from okuchaiev on July 1, 2020 02:03
@titu1994 force-pushed the candidate_optimizer_refactor branch from 1d1d931 to 36f755a on July 1, 2020 05:20
@titu1994 force-pushed the candidate_optimizer_refactor branch from 1d10104 to 122fbde on July 1, 2020 16:38
@okuchaiev (Member) left a comment:

Just a few minor comments.

Resolved review threads (collapsed):
examples/asr/speech_to_text.py (outdated)
examples/asr/speech_to_text.py
nemo/collections/asr/models/ctc_models.py (outdated)
nemo/core/classes/optimizers.py (5 outdated threads)
@okuchaiev requested review from blisc and VahidooX on July 1, 2020 22:27
Resolved review thread (collapsed): nemo/collections/asr/models/ctc_models.py (outdated)
Comment on lines 52 to +54
asr_model.setup_training_data(model_config['AudioToTextDataLayer'])
asr_model.setup_validation_data(model_config['AudioToTextDataLayer_eval'])
-asr_model.setup_optimization(optim_params={'lr': 0.0003})
+asr_model.setup_optimization(optim_params={'optimizer': args.optimizer, 'lr': args.lr, 'opt_args': args.opt_args})
@blisc (Collaborator):

I suppose this is out of scope for this PR, but these three lines look wholly out of line with PyTorch Lightning code. Just do all of this in __init__(). I fail to see the reason we need to do this separately.

@okuchaiev (Member):

@blisc, are you proposing that models' __init__() take (1) model hyperparameters, (2) optimizer hyperparameters, and (3) train/test/eval data parameters, instead of having setup_* functions?

Contributor:

@blisc I came to exactly the same conclusions yesterday - thus my email.

I think the solution is to properly parametrize NeMo Models.

@titu1994 (Collaborator, Author):

We no longer need to manually extract the kwargs from the parsed args; vars(args) is concise and serves the same purpose.
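A quick, assumed illustration of the point: vars() converts the parsed Namespace into a plain dict, so forwarding vars(args) carries the same keys one would otherwise assemble by hand from args.optimizer, args.lr, and args.opt_args.

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--optimizer', default='adam')
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--opt_args', nargs='*', default=None)

args = parser.parse_args(['--optimizer', 'novograd', '--lr', '0.01'])
print(vars(args))  # {'optimizer': 'novograd', 'lr': 0.01, 'opt_args': None}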

Comment on lines 100 to 101
optimizer = get_optimizer(optimizer_name)
self.__optimizer = optimizer(self.parameters(), lr=lr, **optimizer_args)
Collaborator:

Why not merge these two lines into one?

@titu1994 (Collaborator, Author):

We could do that; I just thought it's better to keep them separate in case optimizer_name is not valid and get_optimizer raises an error. The traceback would point to a pretty dense line otherwise. But sure, we can merge them too.
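For context, a registry-style get_optimizer along these lines would fail fast with a readable error before any instantiation happens. This is a sketch under that assumption, with a trimmed registry; it is not claimed to be the PR's exact code.

import torch

# Trimmed, hypothetical registry; the PR's version also covers Novograd.
AVAILABLE_OPTIMIZERS = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD,
}

def get_optimizer(name):
    # Return the optimizer *class* (not an instance) registered under `name`.
    if name not in AVAILABLE_OPTIMIZERS:
        raise ValueError(
            f"Optimizer '{name}' is not available. Choose one of: {list(AVAILABLE_OPTIMIZERS)}"
        )
    return AVAILABLE_OPTIMIZERS[name]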

Collaborator:

This wasn't what I had in mind, actually. I was thinking more of return get_optimizer(optimizer_name, self.parameters(), lr=lr, **optimizer_args), i.e. I would expect get_optimizer to instantiate the optimizer for me.

If you want to keep your original design, I would actually prefer the old:

optimizer = get_optimizer(optimizer_name)
self.__optimizer = optimizer(self.parameters(), lr=lr, **optimizer_args)

rather than the changed:

optimizer = get_optimizer(optimizer_name)(self.parameters(), lr=lr, **optimizer_args)

@titu1994 (Collaborator, Author):

Oh, I misunderstood. Yes, I'll revert to follow the older design. As for merging the two lines, I would prefer not to, for two reasons: 1) we may want the uninstantiated class so we can wrap it in another class (say we have an experimental optimizer), and 2) we may want to pass the class as an argument without instantiating it, to perform deferred computation or typechecks in tests.
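A minimal sketch of those two reasons, assuming a get_optimizer that returns the class; the wrapper here is purely hypothetical.

import torch

def wrap_experimental(optimizer_cls):
    # Reason 1: wrap the *class* before instantiation, e.g. to layer an
    # experimental optimizer on top of a registered one.
    class Wrapped(optimizer_cls):
        pass  # experimental behaviour would go here
    return Wrapped

# Reason 2: pass the class around uninstantiated, for deferred construction
# or simple typechecks in tests.
optimizer_cls = torch.optim.Adam  # stands in for get_optimizer('adam')
assert issubclass(optimizer_cls, torch.optim.Optimizer)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = wrap_experimental(optimizer_cls)(params, lr=0.01)  # instantiate only at the end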

Resolved review threads (collapsed):
nemo/collections/asr/models/ctc_models.py (outdated)
nemo/core/optim/optimizers.py (3 threads, 2 outdated)
@blisc (Collaborator) commented on Jul 1, 2020

Seems mostly fine to me, just some minor comments

@@ -49,7 +51,7 @@ def main(args):
model_config['AudioToTextDataLayer_eval']['manifest_filepath'] = args.eval_dataset
asr_model.setup_training_data(model_config['AudioToTextDataLayer'])
asr_model.setup_validation_data(model_config['AudioToTextDataLayer_eval'])
-asr_model.setup_optimization(optim_params={'lr': 0.0003})
+asr_model.setup_optimization(optim_params=vars(args))
Collaborator:

Are we passing all the args here? If so, I don't think it looks nice to pass all of them. Can we pass three variables instead: lr, optimizer_kind, and opt_params? For example: asr_model.setup_optimization(lr=args.lr, optimizer_kind=args.optimizer_kind, optim_params=args.opt_params)?

That would make it easier for the user to understand. Written this way, it looks like a magic function that takes all the args and returns the optimizer.

@titu1994 (Collaborator, Author):

Yes, that is what it was before, and we can go back to that. I'm a bit worried about how many args are going to be passed for the optimizer + scheduler, but it might be better to be explicit.

@titu1994 (Collaborator, Author):

Reverted it to pass args explicitly. I agree, this looks cleaner and more understandable.

Collaborator:

That's a good point about the scheduler. How about having five inputs: (1) optimizer_kind, (2) lr, (3) opt_params, (4) lr_policy, (5) lr_policy_params?
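If that suggestion is adopted, the resulting method might look roughly like the sketch below. This only illustrates the five-argument proposal; the class name and body are placeholders, not code from this PR.

class ExampleModel:  # hypothetical stand-in for a NeMo model class
    def setup_optimization(self, optimizer_kind='adam', lr=0.001,
                           opt_params=None, lr_policy=None, lr_policy_params=None):
        # Store the five explicit inputs; a real implementation would build the
        # optimizer and (optionally) the LR schedule from them.
        self._optim_config = {
            'optimizer_kind': optimizer_kind,
            'lr': lr,
            'opt_params': opt_params or {},
            'lr_policy': lr_policy,
            'lr_policy_params': lr_policy_params or {},
        }
        return self._optim_config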

@titu1994 (Collaborator, Author):

That's a good idea. We should be able to do that cleanly. For now, though, I am hard-coding the scheduler until the team finds this approach to optimizers good enough to extend to schedulers as well.

Resolved review thread (collapsed): nemo/collections/asr/models/ctc_models.py
@titu1994 marked this pull request as ready for review on July 2, 2020 17:12
@okuchaiev merged commit fe14046 into NVIDIA:candidate on Jul 2, 2020
@titu1994 deleted the candidate_optimizer_refactor branch on July 2, 2020 18:31
@blisc mentioned this pull request on Jul 6, 2020
dcurran90 pushed a commit to dcurran90/NeMo that referenced this pull request Oct 15, 2024
* In addition to the HTML report which is generated
* The term listing will be seen in CI job output

Signed-off-by: Mark Sturdevant <[email protected]>