[Typing][C-95, C-96, C-99] Add annotations for python/paddle/incubate/optimizer/{lookahead, modelaverage, lbfgs}.py #67448

Merged (17 commits, Aug 21, 2024)

Changes from 13 commits

Commits
c58b86c
[Typing][C-95]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 14, 2024
9194f48
[Typing][C-96]Add annotations for `python/paddle/incubate/optimizer/m…
inaomIIsfarell Aug 14, 2024
40a7daf
[Typing][C-99]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 14, 2024
7961d6e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
inaomIIsfarell Aug 14, 2024
80e95f9
[Typing][C-99]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 17, 2024
8dd1d08
[Typing][C-99]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 17, 2024
c302da7
[Typing][C-95]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 17, 2024
d912ef0
[Typing][C-96]Add annotations for `python/paddle/incubate/optimizer/m…
inaomIIsfarell Aug 17, 2024
43a4f2b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
inaomIIsfarell Aug 17, 2024
31e6dfd
[Typing][C-95]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 20, 2024
74c38ff
[Typing][C-96]Add annotations for `python/paddle/incubate/optimizer/m…
inaomIIsfarell Aug 20, 2024
e6e0e4e
[Typing][C-99]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 20, 2024
96ae0bf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
inaomIIsfarell Aug 20, 2024
63f2e09
[Typing][C-95]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 21, 2024
3d061ce
[Typing][C-96]Add annotations for `python/paddle/incubate/optimizer/m…
inaomIIsfarell Aug 21, 2024
e602923
[Typing][C-99]Add annotations for `python/paddle/incubate/optimizer/l…
inaomIIsfarell Aug 21, 2024
2b49e70
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
inaomIIsfarell Aug 21, 2024
50 changes: 35 additions & 15 deletions python/paddle/incubate/optimizer/lbfgs.py
@@ -11,17 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import defaultdict
from functools import reduce
from typing import TYPE_CHECKING, Literal, TypeVar

import paddle
from paddle.optimizer import Optimizer
from paddle.utils import deprecated

from .line_search_dygraph import _strong_wolfe

if TYPE_CHECKING:
from collections.abc import Callable, Sequence

from paddle import Tensor
from paddle.nn.clip import GradientClipBase
from paddle.optimizer.optimizer import _ParameterConfig
from paddle.regularizer import WeightDecayRegularizer

_T_co = TypeVar('_T_co', covariant=True)


@deprecated(since="2.5.0", update_to="paddle.optimizer.LBFGS", level=1)
class LBFGS(Optimizer):
@@ -116,20 +127,29 @@ class LBFGS(Optimizer):

"""

learning_rate: float
max_iter: int
max_eval: int
tolerance_grad: float
tolerance_change: float
history_size: int
line_search_fn: Literal['strong_wolfe'] | None
state: defaultdict
[Review comment, Contributor] Suggested change:
-    state: defaultdict
+    state: dict[str, dict[str, Any]]
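For intuition on this suggestion, a small illustrative sketch (not taken from the Paddle sources): a `defaultdict(dict)` used as optimizer state is, to a type checker, a str-keyed dict of nested dicts holding heterogeneous values, which is exactly what `dict[str, dict[str, Any]]` expresses:

```python
from collections import defaultdict
from typing import Any

# The concrete dict[...] annotation describes how the state is used,
# not the defaultdict implementation detail behind it.
state: dict[str, dict[str, Any]] = defaultdict(dict)
state['global_state'].setdefault('func_evals', 0)  # int value
state['global_state']['n_iter'] = 0                # int value
# Tensors, lists, and floats can sit alongside ints, hence Any.
```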


def __init__(
self,
- learning_rate=1.0,
- max_iter=20,
- max_eval=None,
- tolerance_grad=1e-7,
- tolerance_change=1e-9,
- history_size=100,
- line_search_fn=None,
- parameters=None,
- weight_decay=None,
- grad_clip=None,
- name=None,
- ):
+ learning_rate: float = 1.0,
+ max_iter: int = 20,
+ max_eval: int | None = None,
+ tolerance_grad: float = 1e-7,
+ tolerance_change: float = 1e-9,
+ history_size: int = 100,
+ line_search_fn: Literal['strong_wolfe'] | None = None,
+ parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
+ weight_decay: float | WeightDecayRegularizer | None = None,
+ grad_clip: GradientClipBase | None = None,
+ name: str | None = None,
+ ) -> None:
if max_eval is None:
max_eval = max_iter * 5 // 4

@@ -165,7 +185,7 @@ def __init__(

self._numel_cache = None

- def state_dict(self):
+ def state_dict(self) -> dict[str, Tensor]:
[Review comment, Contributor] Suggested change:
-    def state_dict(self) -> dict[str, Tensor]:
+    def state_dict(self) -> dict[str, dict[str, Any]]:

r"""Returns the state of the optimizer as a :class:`dict`.

Return:
@@ -226,7 +246,7 @@ def _directional_evaluate(self, closure, x, alpha, d):
self._set_param(x)
return loss, flat_grad

- def step(self, closure):
+ def step(self, closure: Callable[[], _T_co]) -> _T_co:
"""
Performs a single optimization step.

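To ground the new `lbfgs.py` annotations, a minimal dygraph usage sketch (the toy data and model are illustrative assumptions, not part of this PR). The covariant `_T_co` makes `step` generic in the closure's return type:

```python
import paddle
from paddle.incubate.optimizer import LBFGS

# Hypothetical toy problem: fit a linear layer by minimizing MSE.
x = paddle.randn([8, 4])
y = paddle.randn([8, 1])
net = paddle.nn.Linear(4, 1)
opt = LBFGS(
    learning_rate=1.0,
    line_search_fn='strong_wolfe',  # matches Literal['strong_wolfe'] | None
    parameters=net.parameters(),
)

def closure() -> paddle.Tensor:
    opt.clear_grad()
    loss = paddle.nn.functional.mse_loss(net(x), y)
    loss.backward()
    return loss

# Since step is annotated Callable[[], _T_co] -> _T_co, a type checker
# infers `loss` below as Tensor, the closure's own return type.
loss = opt.step(closure)
```

Note the class is deprecated since 2.5.0 in favor of `paddle.optimizer.LBFGS`, per the `@deprecated` decorator in the diff.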
34 changes: 30 additions & 4 deletions python/paddle/incubate/optimizer/lookahead.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import paddle
from paddle.base import framework, unique_name
@@ -21,6 +24,13 @@
from paddle.optimizer import Optimizer
from paddle.pir.core import create_parameter

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.framework import Operator
from paddle.base.layer_helper import LayerHelperBase
from paddle.static import Program


__all__ = []


@@ -112,9 +122,21 @@ class LookAhead(Optimizer):

"""

inner_optimizer: Optimizer
alpha: float
k: int
type: str
helper: LayerHelperBase | None
[Review comment, Contributor] Suggested change:
-    helper: LayerHelperBase | None
+    helper: LayerHelper

Is LayerHelperBase really necessary here?


_slow_str = "slow"

- def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None):
+ def __init__(
+ self,
+ inner_optimizer: Optimizer,
+ alpha: float = 0.5,
+ k: int = 5,
+ name: str | None = None,
+ ) -> None:
assert inner_optimizer is not None, "inner optimizer can not be None"
assert (
0.0 <= alpha <= 1.0
@@ -152,7 +174,7 @@ def _set_auxiliary_var(self, key, val):

@framework.dygraph_only
@imperative_base.no_grad
- def step(self):
+ def step(self) -> None:
"""
Execute the optimizer and update parameters once.

@@ -272,8 +294,12 @@ def _append_optimize_op(self, block, param_and_grad):

@imperative_base.no_grad
def minimize(
- self, loss, startup_program=None, parameters=None, no_grad_set=None
- ):
+ self,
+ loss: Tensor,
+ startup_program: Program | None = None,
+ parameters: list[Tensor] | list[str] | None = None,
+ no_grad_set: set[Tensor] | set[str] | None = None,
+ ) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]:
"""
Add operations to minimize ``loss`` by updating ``parameters``.

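To show the annotated `LookAhead` API in use, a minimal dygraph sketch (the toy network, data, and choice of SGD as the inner optimizer are illustrative assumptions):

```python
import paddle
from paddle.incubate import LookAhead

net = paddle.nn.Linear(4, 2)
inner = paddle.optimizer.SGD(learning_rate=0.1, parameters=net.parameters())
# inner_optimizer: Optimizer, alpha: float, k: int, per the annotations above.
lookahead = LookAhead(inner, alpha=0.5, k=5)

x = paddle.randn([8, 4])
loss = net(x).mean()
loss.backward()
lookahead.step()        # annotated -> None: updates fast/slow weights in place
lookahead.clear_grad()
```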
50 changes: 38 additions & 12 deletions python/paddle/incubate/optimizer/modelaverage.py
@@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import paddle
from paddle import _C_ops
from paddle.base import framework
from paddle.base.dygraph import base as imperative_base
- from paddle.base.layer_helper import LayerHelper
+ from paddle.base.layer_helper import LayerHelper, LayerHelperBase
[Review comment, Contributor] Suggested change:
-    from paddle.base.layer_helper import LayerHelper, LayerHelperBase
+    from paddle.base.layer_helper import LayerHelper

LayerHelperBase doesn't seem to be needed.

from paddle.base.wrapped_decorator import signature_safe_contextmanager
from paddle.framework import (
in_dynamic_mode,
@@ -25,6 +28,15 @@
)
from paddle.optimizer import Optimizer

if TYPE_CHECKING:
from collections.abc import Generator, Sequence

from paddle import Tensor
from paddle.base.framework import Program
[Review comment, Contributor] Suggested change:
-    from paddle.base.framework import Program
+    from paddle.static import Program

from paddle.optimizer.optimizer import _ParameterConfig
from paddle.static import Executor


__all__ = []


@@ -169,14 +181,22 @@ class ModelAverage(Optimizer):

"""

helper: LayerHelperBase | None
[Review comment, Contributor] Suggested change:
-    helper: LayerHelperBase | None
+    helper: LayerHelper

Is it really necessary to use LayerHelperBase here?

average_window: float
min_average_window: int
max_average_window: int
type: str
apply_program: Program
restore_program: Program

def __init__(
self,
- average_window_rate,
- parameters=None,
- min_average_window=10000,
- max_average_window=10000,
- name=None,
- ):
+ average_window_rate: float,
+ parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,
+ min_average_window: int = 10000,
+ max_average_window: int = 10000,
+ name: str | None = None,
+ ) -> None:
super().__init__(
learning_rate=0.0,
parameters=parameters,
@@ -296,8 +316,12 @@ def _append_optimize_op(self, block, param_and_grad):

@imperative_base.no_grad
def minimize(
- self, loss, startup_program=None, parameters=None, no_grad_set=None
- ):
+ self,
+ loss: Tensor,
+ startup_program: Program | None = None,
+ parameters: list[Tensor] | None = None,
+ no_grad_set: set[Tensor] | set[str] | None = None,
+ ) -> None:
"""
Add operations to minimize ``loss`` by updating ``parameters``.

@@ -350,7 +374,7 @@ def minimize(

@framework.dygraph_only
@imperative_base.no_grad
- def step(self):
+ def step(self) -> None:
"""
Execute the optimizer and update parameters once.

@@ -395,7 +419,9 @@ def step(self):

@signature_safe_contextmanager
@imperative_base.no_grad
- def apply(self, executor=None, need_restore=True):
+ def apply(
+ self, executor: Executor | None = None, need_restore: bool = True
+ ) -> Generator[None, None, None]:
"""
Apply the average of the cumulative ``Parameter`` to the parameters of the current model.

@@ -474,7 +500,7 @@ def apply(self, executor=None, need_restore=True):
self.restore(executor)

@imperative_base.no_grad
- def restore(self, executor=None):
+ def restore(self, executor: Executor | None = None) -> None:
"""
Restore ``Parameter`` values of current model.

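Finally, a minimal dygraph sketch of the annotated `step`/`apply`/`restore` flow in `modelaverage.py` (the toy model and optimizer are illustrative assumptions). `apply` is wrapped by `signature_safe_contextmanager`, which is why it is annotated as returning `Generator[None, None, None]`:

```python
import paddle
from paddle.incubate import ModelAverage

net = paddle.nn.Linear(4, 1)
opt = paddle.optimizer.Momentum(learning_rate=0.1, parameters=net.parameters())
avg = ModelAverage(
    0.15,  # average_window_rate: float
    parameters=net.parameters(),
    min_average_window=2,
    max_average_window=10,
)

for _ in range(3):
    loss = net(paddle.randn([8, 4])).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()
    avg.step()  # accumulate running averages; annotated -> None

# In dygraph mode the executor argument stays None (Executor | None).
with avg.apply(need_restore=False):
    eval_out = net(paddle.randn([8, 4]))  # evaluated with averaged weights
avg.restore()  # put the original parameter values back; annotated -> None
```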