Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][C-95, C-96, C-99]Add annotations for python/paddle/incubate/optimizer/{lookahead, modelaverage, lbfgs}.py #67448

Merged
merged 17 commits into from
Aug 21, 2024

Conversation

inaomIIsfarell
Copy link
Contributor

@inaomIIsfarell inaomIIsfarell commented Aug 14, 2024

PR Category

User Experience

PR Types

Improvements

Description

类型标注:

  • python/paddle/incubate/optimizer/{lookahead, modelaverage, bfgs}.py

Related links

@SigureMo @megemini

Copy link

paddle-bot bot commented Aug 14, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Aug 14, 2024
@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Aug 15, 2024
Copy link
Contributor

@megemini megemini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个类中实例的参数都没写,如:

class A:
    a: int
    def __init__(self)->None:
        self.a = 1
        self._b = 'b'

上面 A 中的 a ~ 只需要添加可公开属性即可,如 self._b 不需要标注 ~

另外,PR 标题中的标号写错了,应该是 C-95


from collections import defaultdict
from functools import reduce
from typing import TYPE_CHECKING, Callable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

从 collections.abc 导入 Callable

tolerance_change: float = 1e-9,
history_size: int = 100,
line_search_fn: str | None = None,
parameters: list[Tensor] | tuple[Tensor] | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
parameters: list[Tensor] | tuple[Tensor] | None = None,
parameters: Sequence[Tensor] | Sequence[_ParameterConfig] | None = None,

参考 Optimizer 的初始化参数

@@ -226,7 +232,7 @@ def _directional_evaluate(self, closure, x, alpha, d):
self._set_param(x)
return loss, flat_grad

def step(self, closure):
def step(self, closure: Callable) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def step(self, closure: Callable) -> Tensor:
def step(self, closure: Callable[[], _T]) -> _T:

_T = TypeVar('_T', covariant=True)

name=None,
):
average_window_rate: float,
parameters: list[Tensor] | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

startup_program: Program | None = None,
parameters: list[Tensor] | list[str] | None = None,
no_grad_set: set[Tensor] | set[str] | None = None,
) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]:
) -> None:

应该是方法没写完整

def apply(self, executor=None, need_restore=True):
def apply(
self, executor: Executor | None = None, need_restore: bool = True
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
):
) -> Generator[None, None, None]:

megemini

This comment was marked as duplicate.

@inaomIIsfarell inaomIIsfarell changed the title [Typing][C-96, C-96, C-99]Add annotations for python/paddle/incubate/optimizer/{lookahead, modelaverage, lbfgs].py [Typing][C-95, C-96, C-99]Add annotations for python/paddle/incubate/optimizer/{lookahead, modelaverage, lbfgs}.py Aug 15, 2024
tolerance_change: float
history_size: int
line_search_fn: str | None
state: dict[_KT, _VT]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 _KT_VT 的意义是什么,单独使用一个泛型是没有任何约束的,和 dict[Any, Any] 没有区别

tolerance_grad: float
tolerance_change: float
history_size: int
line_search_fn: str | None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
line_search_fn: str | None
line_search_fn: Literal['strong_wolfe'] | None

下同

weight_decay: NotRequired[float | WeightDecayRegularizer | None]
learning_rate: NotRequired[float | Tensor | LRScheduler | None]

_T = TypeVar('_T', covariant=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_T = TypeVar('_T', covariant=True)
_T_co = TypeVar('_T_co', covariant=True)

协变、逆变应该从名字上直接体现


import paddle
from paddle.base import framework, unique_name
from paddle.base.dygraph import base as imperative_base
from paddle.base.framework import Variable
from paddle.base.layer_helper import LayerHelper
from paddle.base.layer_helper import LayerHelper, LayerHelperBase
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

非运行时相关 import
放到 if TYPE_CHECKING:

from paddle.framework import in_pir_mode
from paddle.optimizer import Optimizer
from paddle.pir.core import create_parameter

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.framework import Operator, Program
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Program 从 paddle.static 下 import

from paddle.regularizer import WeightDecayRegularizer
from paddle.static import Executor

class _ParameterConfig(TypedDict):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个 _ParameterConfig,有什么理由不复用 python/paddle/optimizer/optimizer.py 中的 _ParameterConfig

tolerance_change: float
history_size: int
line_search_fn: Literal['strong_wolfe'] | None
state: defaultdict
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
state: defaultdict
state: dict[str, dict[str, Any]]

@@ -165,7 +185,7 @@ def __init__(

self._numel_cache = None

def state_dict(self):
def state_dict(self) -> dict[str, Tensor]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def state_dict(self) -> dict[str, Tensor]:
def state_dict(self) -> dict[str, dict[str, Any]]:

alpha: float
k: int
type: str
helper: LayerHelperBase | None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
helper: LayerHelperBase | None
helper: LayerHelper

没必要用 LayerHelperBase 吧?


import paddle
from paddle import _C_ops
from paddle.base import framework
from paddle.base.dygraph import base as imperative_base
from paddle.base.layer_helper import LayerHelper
from paddle.base.layer_helper import LayerHelper, LayerHelperBase
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from paddle.base.layer_helper import LayerHelper, LayerHelperBase
from paddle.base.layer_helper import LayerHelper

貌似不需要 LayerHelperBase

from collections.abc import Generator, Sequence

from paddle import Tensor
from paddle.base.framework import Program
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from paddle.base.framework import Program
from paddle.static import Program

@@ -169,14 +181,22 @@ class ModelAverage(Optimizer):

"""

helper: LayerHelperBase | None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
helper: LayerHelperBase | None
helper: LayerHelper

没有必要用 LayerHelperBase 吧?

Copy link
Contributor

@megemini megemini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM ~

@luotao1 luotao1 merged commit 99daec4 into PaddlePaddle:develop Aug 21, 2024
28 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants