Conversation
@mxnet-label-bot add [pr-awaiting-review]
@register
class LAMB(Optimizer):
    """LAMB Optimizer.
pls add doc
working on it now
The name clashes with the Gluon one; can we give it a different name?
Please add a description and a reference to the paper in the PR.
I see a crash in the way the Gluon trainer is calling the optimizer...
Done
@access2rohit could you answer Sam's comments?
Done
float beta1;
float beta2;
float epsilon;
float t;
@eric-haibin-lin @access2rohit I found this issue when reading the code. Here, t should be the number of updates and should not be stored as a float, which loses precision. I think we need to store it as index_t.
We are using float here for an integer data type. @sxjscience can you explain how we would lose precision for the operation beta^t?
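(Editorial aside, not from the thread: a small numpy check of the concern above, showing that an update count kept in float32 stops being exact once it passes 2**24.)

import numpy as np

# float32 has a 24-bit significand, so integer update counts above 2**24
# can no longer all be represented exactly; some steps collapse together.
for t in (2**24, 2**24 + 1, 2**24 + 2):
    print(t, int(np.float32(t)))
# 16777216 16777216
# 16777217 16777216   <- rounds back to the previous step
# 16777218 16777218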
if (bias_correction) {
  DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
  DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
Actually, in apex, it uses float32 to calculate the power and then switches to float16:
https://github.com/NVIDIA/apex/blob/325f5a0bec542701edba1628ad34f3b2ea47c556/csrc/multi_tensor_lamb.cu#L231-L249
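(Editorial aside, not from the thread: a numpy sketch of that apex approach, computing 1 - beta**t in float32 and only then casting to the float16 storage type; the helper name is made up.)

import numpy as np

def bias_correction(beta, t, out_dtype=np.float16):
    # Compute 1 - beta**t in float32, then cast to the storage dtype,
    # mirroring the apex code linked above (illustrative only).
    correction = np.float32(1.0) - np.power(np.float32(beta), np.float32(t))
    return correction.astype(out_dtype)

# Early in training the correction term is tiny, so computing it end-to-end
# in float16 is noticeably less accurate than computing in float32 and casting:
print(bias_correction(0.999, 2))                  # ~0.001999
print(np.float16(1.0) - np.float16(0.999) ** 2)   # ~0.001953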
* initial commit lamb optimizer
* fixing base lamb optimizer
* adding API doc for Lamb Phase 1 and 2
Description
Adding two new operators:
Link to paper: https://arxiv.org/pdf/1904.00962.pdf
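(Editorial aside, not part of the PR's code: a minimal numpy sketch of the update rule from the linked paper, for reviewers; the argument names, defaults, and exact phase split here are assumptions.)

import numpy as np

def lamb_step(weight, grad, mean, var, t,
              lr=0.01, beta1=0.9, beta2=0.999, eps=1e-6, wd=0.01):
    # Phase 1: Adam-style first/second moments with bias correction,
    # plus weight decay folded into the update direction.
    mean[:] = beta1 * mean + (1.0 - beta1) * grad
    var[:] = beta2 * var + (1.0 - beta2) * grad * grad
    mean_hat = mean / (1.0 - beta1 ** t)
    var_hat = var / (1.0 - beta2 ** t)
    update = mean_hat / (np.sqrt(var_hat) + eps) + wd * weight

    # Phase 2: layer-wise trust ratio ||w|| / ||update|| rescales the step.
    w_norm = np.linalg.norm(weight)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    weight -= lr * trust_ratio * update
    return weight

The first half is presumably the "Phase 1" the commit messages mention, and the trust-ratio rescaling "Phase 2".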
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Testing