[Typing][A-46] Add type annotations for paddle/optimizer/nadam.py #65273

Merged · 4 commits · Jun 21, 2024
51 changes: 36 additions & 15 deletions python/paddle/optimizer/nadam.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Sequence

from paddle import _C_ops
from paddle.base.libpaddle import DataType
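
Reviewer note: the new ``from __future__ import annotations`` line makes every annotation in this module lazily evaluated (PEP 563), which is what allows the ``float | Tensor`` union syntax on Python versions older than 3.10 and lets signatures reference names that are only imported under ``TYPE_CHECKING``. A minimal sketch of that pattern, outside this file (the ``scale`` function and its names are illustrative only):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported for static analysis only; this branch never runs at runtime,
        # so there is no import cost and no circular-import risk.
        from paddle import Tensor

    def scale(x: Tensor, factor: float | Tensor = 1.0) -> Tensor:
        # With PEP 563 the annotations above are stored as strings, so Tensor
        # does not need to be importable when this function is defined.
        return x * factor
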
@@ -24,6 +27,22 @@
)
from .optimizer import Optimizer

if TYPE_CHECKING:
from typing_extensions import NotRequired

from paddle import Tensor
from paddle.nn.clip import GradientClipBase

from .lr import LRScheduler
from .optimizer import _ParameterConfig

class _NAdamParameterConfig(_ParameterConfig):
beta1: NotRequired[float | Tensor]
beta2: NotRequired[float | Tensor]
epsilon: NotRequired[float]
momentum_decay: NotRequired[float]


__all__ = []


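Reviewer note: ``_NAdamParameterConfig`` extends the shared ``_ParameterConfig`` TypedDict so that per-group overrides of ``beta1``, ``beta2``, ``epsilon``, and ``momentum_decay`` are visible to static checkers; every key is ``NotRequired``, so any subset may appear. A hedged sketch of the parameter-group shape this describes (layer size and override values are illustrative; whether a given checker accepts the dict literal without a cast or ``# type: ignore`` depends on how it infers the heterogeneous list):

    import paddle

    linear = paddle.nn.Linear(10, 10)
    opt = paddle.optimizer.NAdam(
        learning_rate=0.002,
        parameters=[
            {
                "params": linear.parameters(),
                # Optional per-group overrides matching _NAdamParameterConfig keys.
                "beta1": 0.85,
                "momentum_decay": 0.002,
            }
        ],
    )
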
@@ -52,7 +71,7 @@ class NAdam(Optimizer):
Args:
learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
It can be a float value or a LRScheduler. The default value is 0.002.
parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``.
parameters (list|tuple|None, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``.
This parameter is required in dygraph mode. You can also specify different options for
different parameter groups, such as the learning rate and weight decay; in that case,
the parameters are a list of dicts. Note that the learning_rate in parameter groups
@@ -66,14 +85,14 @@
The default value is 0.999.
epsilon (float, optional): A small float value for numerical stability.
The default value is 1e-08.
weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor.
weight_decay (float|Tensor|None, optional): The weight decay coefficient, it can be float or Tensor.
Default None, meaning there is no regularization.
momentum_decay (float, optional): The momentum decay coefficient. The default value is 0.004.
grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
grad_clip (GradientClipBase|None, optional): Gradient clipping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` ,
:ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
name (str, optional): Normally there is no need for user to set this property.
name (str|None, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.

@@ -105,7 +124,7 @@ class NAdam(Optimizer):
>>> loss = paddle.mean(out)
>>> opt = paddle.optimizer.NAdam(
... learning_rate=0.1,
... parameters=[{
... parameters=[{ # type: ignore
... 'params': linear_1.parameters()
... }, {
... 'params': linear_2.parameters(),
@@ -130,16 +149,18 @@ class NAdam(Optimizer):

def __init__(
self,
learning_rate=0.002,
beta1=0.9,
beta2=0.999,
epsilon=1.0e-8,
momentum_decay=0.004,
parameters=None,
weight_decay=None,
grad_clip=None,
name=None,
):
learning_rate: float | LRScheduler = 0.002,
beta1: float | Tensor = 0.9,
beta2: float | Tensor = 0.999,
epsilon: float = 1.0e-8,
momentum_decay: float = 0.004,
parameters: Sequence[Tensor]
| Sequence[_NAdamParameterConfig]
| None = None,
weight_decay: float | Tensor | None = None,
grad_clip: GradientClipBase | None = None,
name: str | None = None,
) -> None:
if isinstance(learning_rate, (float, int)) and not 0.0 <= learning_rate:
raise ValueError(
f"Invalid learning rate: {learning_rate}, expect learning_rate >= 0."
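
Reviewer note: with the annotated ``__init__``, the usual construction paths are now statically checked: a float or ``LRScheduler`` learning rate, float or ``Tensor`` betas, an optional weight decay, and an optional ``GradientClipBase``. A minimal hedged sketch exercising those annotations (the scheduler, clip norm, and layer size are illustrative, not taken from this diff):

    import paddle

    linear = paddle.nn.Linear(4, 4)
    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.002, T_max=100)
    opt = paddle.optimizer.NAdam(
        learning_rate=scheduler,                        # float | LRScheduler
        beta1=0.9,                                      # float | Tensor
        beta2=0.999,
        epsilon=1e-8,
        momentum_decay=0.004,
        parameters=linear.parameters(),                 # Sequence[Tensor]
        weight_decay=0.01,                              # float | Tensor | None
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),  # GradientClipBase | None
        name=None,                                      # str | None
    )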