Add support for NeMo-scope Optimizers and add Novograd Optimizer #793

Merged
merged 15 commits on Jul 2, 2020
Changes from 9 commits
5 changes: 4 additions & 1 deletion examples/asr/speech_to_text.py
@@ -21,6 +21,7 @@
# --distributed_backend "ddp" \
# --max_epochs 1 \
# --fast_dev_run \
# --lr 0.001 \

from argparse import ArgumentParser

@@ -29,6 +30,7 @@

from nemo.collections.asr.arguments import add_asr_args
from nemo.collections.asr.models import EncDecCTCModel
from nemo.core.classes.optimizers import add_optimizer_args


def main(args):
@@ -49,7 +51,7 @@ def main(args):
model_config['AudioToTextDataLayer_eval']['manifest_filepath'] = args.eval_dataset
asr_model.setup_training_data(model_config['AudioToTextDataLayer'])
asr_model.setup_validation_data(model_config['AudioToTextDataLayer_eval'])
asr_model.setup_optimization(optim_params={'lr': 0.0003})
asr_model.setup_optimization(optim_params={'optimizer': args.optimizer, 'lr': args.lr, 'opt_args': args.opt_args})
Comment on lines 52 to +54
Collaborator

I suppose this is out of scope of this PR, but these three lines look wholly out of line with PyTorch Lightning code. Just do all of this in init(). I fail to see the reason we need to do this separately.

Member

@blisc are you proposing to have models.init() take: (1) model hyper parameters, (2) optimizer hyper parameters and (3) train/test/eval data parameters instead of having setup_* functions?

Contributor

@blisc I came to exactly the same conclusions yesterday - thus my email.

I think the solution is to properly parametrize NeMo Models.

Collaborator Author

We no longer need to manually extract the kwargs from the parsed args; vars(args) is concise and serves the same purpose.
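A minimal sketch of the pattern being described, assuming the parser constructed later in this file; forwarding the whole namespace is illustrative only, not what this commit actually does:

# Illustrative only: pass the parsed namespace as a dict.
# setup_optimization() only reads the 'optimizer', 'lr' and 'opt_args' keys,
# so any extra argparse keys are ignored.
args = parser.parse_args()
asr_model.setup_optimization(optim_params=vars(args))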

# trainer = pl.Trainer(
# val_check_interval=1, amp_level='O1', precision=16, gpus=4, max_epochs=123, distributed_backend='ddp'
# )
@@ -62,6 +64,7 @@ def main(args):
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = add_asr_args(parser)
parser = add_optimizer_args(parser)

args = parser.parse_args()

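With add_optimizer_args wired into the parser, an illustrative invocation of this script (flag values are placeholders, other flags elided) could look like:

# python examples/asr/speech_to_text.py \
#     ... \
#     --optimizer novograd \
#     --lr 0.01 \
#     --opt_args weight_decay=0.001 betas=0.95,0.98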
27 changes: 24 additions & 3 deletions nemo/collections/asr/models/ctc_models.py
@@ -21,7 +21,8 @@
from nemo.collections.asr.metrics.wer import monitor_asr_train_progress
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.parts.features import WaveformFeaturizer
from nemo.core.classes.common import Serialization, typecheck
from nemo.core.classes.common import Serialization, logging, typecheck
from nemo.core.classes.optimizers import get_optimizer, parse_optimizer_args
from nemo.core.neural_types import *
from nemo.utils import logging
from nemo.utils.decorators import experimental
@@ -79,7 +80,28 @@ def setup_test_data(self, test_data_layer_params: Optional[Dict]):
self.__test_dl = self.__setup_dataloader_from_config(config=test_data_layer_params)

def setup_optimization(self, optim_params: Optional[Dict]):
self.__optimizer = torch.optim.Adam(self.parameters(), lr=optim_params['lr'])
optim_params = optim_params or {} # In case null was passed as optim_params

# Check if caller provided optimizer name, default to Adam otherwise
optimizer_name = optim_params.get('optimizer', 'adam')

# Check if caller has optimizer kwargs, default to empty dictionary
optimizer_args = optim_params.get('opt_args', [])
optimizer_args = parse_optimizer_args(optimizer_args)

# We are guaranteed to have lr since it is required by the argparser
# But maybe the user forgot to pass it to this function
lr = optim_params.get('lr', None)

if lr is None:
raise ValueError('`lr` must be passed when setting up the optimization!')

# Actually instantiate the optimizer
optimizer = get_optimizer(optimizer_name)
self.__optimizer = optimizer(self.parameters(), lr=lr, **optimizer_args)
Collaborator

why not merge these two lines into one?

Collaborator Author

We could do that; I just thought it's better to keep them separate in case optimizer_name is not valid and get_optimizer raises an error. The traceback would point to a pretty dense line in that case. But sure, we can merge them too.

Collaborator

This wasn't what I had in mind, actually. I was thinking more of return get_optimizer(optimizer_name, self.parameters(), lr=lr, **optimizer_args), i.e. I would expect get_optimizer to instantiate an optimizer for me.

If you want to keep your original design, I would actually prefer the old:

optimizer = get_optimizer(optimizer_name)
self.__optimizer = optimizer(self.parameters(), lr=lr, **optimizer_args)

rather than the changed:

optimizer = get_optimizer(optimizer_name)(self.parameters(), lr=lr, **optimizer_args)

Collaborator Author

Oh, I misunderstood. Yes, I'll revert to follow the older design. As to merging the two lines together, I would prefer not to do that for two reasons: 1) we may want the class without instantiation so we can wrap it into another class (say we have an experimental optimizer), 2) we may want to pass the class as an argument without instantiation to perform deferred computation or typechecks in tests.
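A short illustrative sketch of the deferred-instantiation point (model below is a placeholder torch.nn.Module):

# get_optimizer() returns the optimizer class wrapped in a functools.partial,
# so it can be passed around, wrapped, or type-checked before being instantiated.
optimizer_cls = get_optimizer('novograd')
# ... later, once parameters and hyperparameters are known ...
optimizer = optimizer_cls(model.parameters(), lr=0.01)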


# TODO: Remove after demonstration
logging.info("Optimizer config = %s", str(self.__optimizer))

@classmethod
def list_available_models(cls) -> Optional[Dict[str, str]]:
@@ -176,7 +198,6 @@ def training_step(self, batch, batch_nb):
def validation_step(self, batch, batch_idx):
self.eval()
audio_signal, audio_signal_len, transcript, transcript_len = batch
logging.info("Performing forward of validation step")
log_probs, encoded_len, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)
# loss_value = self.loss.loss_function(
loss_value = self.loss(
7 changes: 7 additions & 0 deletions nemo/core/classes/__init__.py
@@ -18,3 +18,10 @@
from nemo.core.classes.loss import Loss
from nemo.core.classes.modelPT import ModelPT
from nemo.core.classes.module import NeuralModule
from nemo.core.classes.optimizers import (
Novograd,
add_optimizer_args,
get_optimizer,
parse_optimizer_args,
register_optimizer,
)
281 changes: 281 additions & 0 deletions nemo/core/classes/optimizers.py
@@ -0,0 +1,281 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from argparse import ArgumentParser
from functools import partial

import torch
import torch.optim as optim
from torch.optim import adadelta, adagrad, adamax, rmsprop, rprop
from torch.optim.optimizer import Optimizer

__all__ = ['Novograd', 'get_optimizer', 'register_optimizer', 'parse_optimizer_args', 'add_optimizer_args']


AVAILABLE_OPTIMIZERS = {
'sgd': optim.SGD,
'adam': optim.Adam,
'adamw': optim.AdamW,
'adadelta': adadelta.Adadelta,
'adamax': adamax.Adamax,
'adagrad': adagrad.Adagrad,
'rmsprop': rmsprop.RMSprop,
'rprop': rprop.Rprop,
}


def _boolify(s):
if s == 'True' or s == 'true':
return True
if s == 'False' or s == 'false':
return False
raise ValueError('Not Boolean Value!')


def _autocast(value):
# If value is itself None, don't parse
if value is None:
return None

# If value is a comma-separated list of items, recursively parse all items in the list.
if "," in value:
values = value.split(',')
values = [_autocast(value) for value in values]
return values

# If value is string `none` or `None`, parse as None
if value == 'none' or value == 'None':
return None

# Try type cast and return
for cast_type in (int, float, _boolify):
try:
return cast_type(value)
except Exception:
pass

# All types failed, return str without casting
return value # str type
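# Illustrative examples of _autocast:
#   _autocast('0.001')     -> 0.001 (float)
#   _autocast('true')      -> True
#   _autocast('0.9,0.999') -> [0.9, 0.999]
#   _autocast('none')      -> None
#   _autocast('foo')       -> 'foo' (left as str)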


def _check_valid_opt_params(lr, eps, betas):
if lr < 0:
raise ValueError(f"Invalid learning rate: {lr}")
if eps < 0:
raise ValueError(f"Invalid epsilon value: {eps}")
if not (0.0 <= betas[0] < 1.0 and 0.0 <= betas[1] < 1.0):
raise ValueError(f"Betas have to be between 0 and 1: {betas}")


def parse_optimizer_args(optimizer_kwargs):
kwargs = {}

if optimizer_kwargs is None:
return kwargs

# If it is a pre-defined dictionary, just return its values
if hasattr(optimizer_kwargs, 'keys'):
return optimizer_kwargs

# If it is key=value string list, parse all items
for key_value in optimizer_kwargs:
key, str_value = key_value.split('=')

value = _autocast(str_value)
kwargs[key] = value

return kwargs
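# Illustrative examples of parse_optimizer_args:
#   parse_optimizer_args(['weight_decay=0.001', 'betas=0.9,0.98'])
#       -> {'weight_decay': 0.001, 'betas': [0.9, 0.98]}
#   parse_optimizer_args({'weight_decay': 0.001})  -> dict is returned as-is
#   parse_optimizer_args(None)                     -> {}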


def add_optimizer_args(parent_parser: ArgumentParser, optimizer='adam', default_opt_args=None) -> ArgumentParser:
"""Extends existing argparse with support for optimizers.

Args:
parent_parser (ArgumentParser): Custom CLI parser that will be extended.
optimizer (str): Name of the default optimizer.
default_opt_args (list(str)): List of overriding arguments for the instantiated optimizer.

Returns:
ArgumentParser: Parser extended by Optimizers arguments.
"""
if default_opt_args is None:
default_opt_args = []

parser = ArgumentParser(parents=[parent_parser], add_help=True, conflict_handler='resolve')

parser.add_argument('--optimizer', type=str, default=optimizer, help='Name of the optimizer. Defaults to Adam.')
parser.add_argument('--lr', type=float, required=True, help='Learning rate of the optimizer.')
parser.add_argument(
'--opt_args',
default=default_opt_args,
nargs='+',
type=str,
help='Overriding arguments for the optimizer. \n'
'Must follow the pattern: \n'
'name=value separated by spaces.',
)

return parser
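# Illustrative usage of add_optimizer_args (values are placeholders):
#   parser = ArgumentParser()
#   parser = add_optimizer_args(parser, optimizer='novograd', default_opt_args=['weight_decay=0.001'])
#   args = parser.parse_args(['--lr', '0.01'])
#   # args.optimizer == 'novograd', args.lr == 0.01, args.opt_args == ['weight_decay=0.001']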


def register_optimizer(name, optimizer: Optimizer):
if name in AVAILABLE_OPTIMIZERS:
raise ValueError(f"Cannot override pre-existing optimizers. Conflicting optimizer name = {name}")

AVAILABLE_OPTIMIZERS[name] = optimizer


def get_optimizer(name, **kwargs):
if name not in AVAILABLE_OPTIMIZERS:
raise ValueError(
f"Cannot resolve optimizer '{name}'. Available optimizers are : " f"{AVAILABLE_OPTIMIZERS.keys()}"
)

optimizer = AVAILABLE_OPTIMIZERS[name]
optimizer = partial(optimizer, **kwargs)
return optimizer
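# Illustrative usage of register_optimizer / get_optimizer
# (MyOptimizer is a hypothetical torch Optimizer subclass, model any torch.nn.Module):
#   register_optimizer('my_optimizer', MyOptimizer)
#   opt_cls = get_optimizer('my_optimizer')       # functools.partial over the class
#   opt = opt_cls(model.parameters(), lr=0.01)    # instantiate like any torch optimizer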


def master_params(optimizer):
"""
Generator expression that iterates over the params owned by ``optimizer``.
Args:
optimizer: An optimizer previously returned from ``amp.initialize``.
"""
for group in optimizer.param_groups:
for p in group['params']:
yield p


class Novograd(Optimizer):
"""Implements Novograd algorithm.
It has been proposed in "Stochastic Gradient Methods with Layer-wise
Adaptive Moments for Training of Deep Networks"
(https://arxiv.org/abs/1905.11286)
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.95, 0.98))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper "On the Convergence of Adam and Beyond"
"""

def __init__(
self,
params,
lr=1e-3,
betas=(0.95, 0.98),
eps=1e-8,
weight_decay=0,
grad_averaging=False,
amsgrad=False,
luc=False,
luc_trust=1e-3,
luc_eps=1e-8,
):
_check_valid_opt_params(lr, eps, betas)
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad,
)
self.luc = luc
self.luc_trust = luc_trust
self.luc_eps = luc_eps
super(Novograd, self).__init__(params, defaults)

def __setstate__(self, state):
super(Novograd, self).__setstate__(state)
for group in self.param_groups:
group.setdefault("amsgrad", False)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("Sparse gradients are not supported.")
amsgrad = group["amsgrad"]
state = self.state[p]

# State initialization
if not state:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)
if amsgrad:
# Maintains max of all exp moving avg of squared grad
state["max_exp_avg_sq"] = torch.zeros([]).to(state["exp_avg"].device)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
beta1, beta2 = group["betas"]

state["step"] += 1

norm = grad.norm().pow(2)
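# Note: unlike Adam, the second moment here is a single scalar per parameter
# tensor (layer-wise), built from the squared gradient norm computed above.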

if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)

if amsgrad:
# Maintains max of all 2nd moment running avg till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
else:
denom = exp_avg_sq.sqrt().add_(group["eps"])

grad.div_(denom)
if group["weight_decay"] != 0:
grad.add_(group["weight_decay"], p.data)
if group["grad_averaging"]:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)

if self.luc:
# Clip update so that updates are less than eta*weights
data_norm = torch.norm(p.data)
grad_norm = torch.norm(exp_avg.data)
luc_factor = self.luc_trust * data_norm / (grad_norm + self.luc_eps)
luc_factor = min(luc_factor, group["lr"])
p.data.add_(-luc_factor, exp_avg)
else:
p.data.add_(-group["lr"], exp_avg)

return loss


# Register Novograd
register_optimizer('novograd', Novograd)
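Finally, a minimal end-to-end sketch of the API introduced in this file, illustrative only: the linear model and random data are placeholders, and the deprecated tensor.add_(alpha, tensor) calls above assume a 2020-era PyTorch.

import torch
import torch.nn as nn

from nemo.core.classes.optimizers import get_optimizer, parse_optimizer_args

# Build a toy model and parse optimizer kwargs the same way the CLI path would.
model = nn.Linear(10, 2)
opt_kwargs = parse_optimizer_args(['weight_decay=0.001', 'betas=0.95,0.98'])

# Resolve the registered 'novograd' optimizer and instantiate it.
optimizer = get_optimizer('novograd')(model.parameters(), lr=0.01, **opt_kwargs)

# One dummy training step.
x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()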