
[Typing][A-46] Add type annotations for paddle/optimizer/nadam.py #65273

Merged: 4 commits into PaddlePaddle:develop from typing-a46-nadam, Jun 21, 2024

Conversation

@enkilee (Contributor) commented Jun 19, 2024:

PR Category

User Experience

PR Types

Improvements

Description

paddle-bot commented Jun 19, 2024:

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Jun 19, 2024
from .optimizer import Optimizer, _ParameterConfig


class _NAdamParameterConfig(_ParameterConfig):
Member commented on the diff:

How does this differ from _AdamParameterConfig? Why define a new one?

@enkilee (Contributor, Author) replied:

adam defines it this way, but sgd and the others don't; the example code has:

            ...     parameters=[{
            ...         'params': linear_1.parameters()
            ...     }, {
            ...         'params': linear_2.parameters(),
            ...         'weight_decay': 0.001,
            ...         'learning_rate': 0.1,
            ...         'beta1': 0.8
            ...     }],

Given that, should I change it accordingly?

Member replied:

@megemini could you take a look at this API? Can _AdamParameterConfig be reused directly?

Contributor replied:

I went over the code again; the ParameterConfig for both Adam and NAdam should be extended a bit ~

The method to focus on is _update_param_group. First, python/paddle/optimizer/adam.py:

    def _update_param_group(self, parameters):
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._lazy_mode = parameters.get(
            'lazy_mode', self._default_dict['lazy_mode']
        )
        parameters = parameters.get('params')
        return parameters
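The lookup pattern above can be exercised in isolation: each hyperparameter is read from the per-group dict and falls back to the optimizer-wide defaults. A minimal standalone mimic (all names here are illustrative, not Paddle's actual implementation):

```python
# Minimal sketch of the per-group-override pattern: group settings take
# precedence, missing keys fall back to the optimizer-wide defaults.
def update_param_group(group, default_dict):
    beta1 = group.get('beta1', default_dict['beta1'])
    beta2 = group.get('beta2', default_dict['beta2'])
    params = group.get('params')
    return params, beta1, beta2

defaults = {'beta1': 0.9, 'beta2': 0.999}
group = {'params': ['w1', 'w2'], 'beta1': 0.8}  # overrides beta1 only
params, b1, b2 = update_param_group(group, defaults)
print(params, b1, b2)  # ['w1', 'w2'] 0.8 0.999
```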

So _AdamParameterConfig should probably be:

class _AdamParameterConfig(_ParameterConfig):
    beta1: NotRequired[float | Tensor]
    beta2: NotRequired[float | Tensor]
    epsilon: NotRequired[float]
    lazy_mode: NotRequired[bool]

Then look at python/paddle/optimizer/nadam.py:

    def _update_param_group(self, parameters):
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._momentum_decay = parameters.get(
            'momentum_decay', self._default_dict['momentum_decay']
        )
        parameters = parameters.get('params')
        return parameters

So _NAdamParameterConfig should be:

class _NAdamParameterConfig(_ParameterConfig):
    beta1: NotRequired[float | Tensor]
    beta2: NotRequired[float | Tensor]
    epsilon: NotRequired[float]
    momentum_decay: NotRequired[float]

Right?

Member replied:

Yes, following this would be more accurate ~

Contributor replied:

Let me check the interfaces that have already been merged and see whether they need the same change ~

@SigureMo (Member) commented:

2024-06-19 16:24:02 ----------------Check results--------------------
2024-06-19 16:24:02 paddle.optimizer.NAdam:1
2024-06-19 16:24:02 
2024-06-19 16:24:02 
2024-06-19 16:24:02 >>> Mistakes found in type checking!
2024-06-19 16:24:02 >>> Please recheck the type annotations. Run `tools/type_checking.py` to check the typing issues:
2024-06-19 16:24:02 > python tools/type_checking.py paddle.optimizer.NAdam
2024-06-19 16:24:02 ----------------End of the Check--------------------
2024-06-19 16:24:02 ----------------End of the Check--------------------

Strange, why is this error report empty?

@megemini (Contributor) commented:

> [CI log quoted above]
> Strange, why is this error report empty?

I tried locally just now: the first check fails, but subsequent checks pass:

First check: [screenshot: type check fails]

Subsequent checks: [screenshot: type check passes]

The first check flagged modules that are clearly unrelated to this API; it's not yet clear what causes this, and clearing mypy's cache does not reproduce the first failure ~

Could we turn on type_checking's debug mode first?

@megemini (Contributor) commented:

Also, the log reports that mypy itself crashed:

2024-06-19 16:24:02 /usr/local/lib/python3.10/dist-packages/paddle/optimizer/nadam.py:150: error: INTERNAL ERROR -- Please try using mypy master on GitHub:
2024-06-19 16:24:02 https://mypy.readthedocs.io/en/stable/common_issues.html#using-a-development-mypy-build
2024-06-19 16:24:02 If this issue continues with mypy master, please report a bug at https://github.com/python/mypy/issues

Is it related to this as well?

@megemini (Contributor) commented:

I just tracked it down; the problem is in the previously merged Optimizer code:

class _ParameterConfig(TypedDict):
    params: Sequence[Tensor]
    weight_decay: NotRequired[float | WeightDecayRegularizer | None]
    learning_rate: NotRequired[float | Tensor | LRScheduler | None]


if TYPE_CHECKING:
    from paddle import Tensor
    from paddle.nn.clip import GradientClipBase

    from ..base.framework import Operator, Program

Here _ParameterConfig uses Tensor, but Tensor is only imported afterwards ~ I debugged locally (debug mode on, mypy run with --show-traceback; see PR https://github.com/PaddlePaddle/Paddle/pull/65319):

(venv38dev)  ✘ shun@shun-B660M-Pro-RS  ~/Documents/Projects/paddle/megemini/Paddle/tools   typing_debug_mode  python type_checking.py --debug paddle.optimizer.NAdam
----------------Codeblock Type Checking Start--------------------
>>> Get docstring from api ...
API_PR is diff from API_DEV: dict_keys(['paddle.optimizer.NAdam'])
Total api: 1
>>> Running type checker ...
/home/shun/venv38dev/lib/python3.8/site-packages/paddle/optimizer/nadam.py:150: error: INTERNAL ERROR -- Please try using mypy master on GitHub:
https://mypy.readthedocs.io/en/stable/common_issues.html#using-a-development-mypy-build
Please report a bug at https://github.com/python/mypy/issues
version: 1.10.0
Traceback (most recent call last):
  File "type_checking.py", line 325, in <module>
    run_type_checker(args, mypy_checker)
  File "type_checking.py", line 304, in run_type_checker
    test_results = get_test_results(type_checker, docstrings_to_test)
  File "type_checking.py", line 272, in get_test_results
    with multiprocessing.Pool(initializer=init_worker) as pool:
  File "/usr/lib/python3.8/multiprocessing/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 212, in __init__
    self._repopulate_pool()
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 303, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 326, in _repopulate_pool_static
    w.start()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usr/lib/python3.8/multiprocessing/context.py", line 277, in _Popen
    return Popen(process_obj)
  File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 75, in _launch
    code = process_obj._bootstrap(parent_sentinel=child_r)
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "mypy/semanal.py", line 6714, in accept
  File "mypy/nodes.py", line 787, in accept
  File "mypy/semanal.py", line 835, in visit_func_def
  File "mypy/semanal.py", line 870, in analyze_func_def
  File "mypy/semanal.py", line 6386, in defer
AssertionError: Must not defer during final iteration
/home/shun/venv38dev/lib/python3.8/site-packages/paddle/optimizer/nadam.py:150: : note: use --pdb to drop into pdb
--------------------
>>> Type hints with api paddle.optimizer.NAdam:1 start ...
import paddle
inp = paddle.rand([10,10], dtype="float32")
linear = paddle.nn.Linear(10, 10)
out = linear(inp)
loss = paddle.mean(out)
nadam = paddle.optimizer.NAdam(learning_rate=0.1,
                    parameters=linear.parameters())
out.backward()
nadam.step()
nadam.clear_grad()
linear_1 = paddle.nn.Linear(10, 10)
linear_2 = paddle.nn.Linear(10, 10)
inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
out = linear_1(inp)
out = linear_2(out)
loss = paddle.mean(out)
opt = paddle.optimizer.NAdam(
    learning_rate=0.1,
    parameters=[{ # type: ignore
        'params': linear_1.parameters()
    }, {
        'params': linear_2.parameters(),
        'weight_decay': 0.001,
        'learning_rate': 0.1,
        'beta1': 0.8
    }],
    weight_decay=0.01,
    beta1=0.9
)
loss.backward()
opt.step()
opt.clear_grad()
>>> Results ...
>>> mypy normal_report is ...

>>> mypy error_report is ...

>>> mypy exit_status is ...
2
>>> Type hints with api paddle.optimizer.NAdam:1 end...
>>> Print summary ...
----------------Check results--------------------
----------------Check results--------------------
paddle.optimizer.NAdam:1
paddle.optimizer.NAdam:1




>>> Mistakes found in type checking!
>>> Mistakes found in type checking!
>>> Please recheck the type annotations. Run `tools/type_checking.py` to check the typing issues:
>>> Please recheck the type annotations. Run `tools/type_checking.py` to check the typing issues:
> python tools/type_checking.py paddle.optimizer.NAdam
> python tools/type_checking.py paddle.optimizer.NAdam
----------------End of the Check--------------------
----------------End of the Check--------------------
----------------End of the Check--------------------
----------------End of the Check--------------------

The end of the log shows no error message either; it's only visible with debug on. It looks like a defer-ordering problem ~
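For reference, one pattern that avoids this class of problem is postponed annotation evaluation: with `from __future__ import annotations`, a name imported only under TYPE_CHECKING can safely appear in annotations, because annotations are never evaluated at runtime. A minimal illustrative sketch (not the actual fix; `Tensor` here is a hypothetical stand-in):

```python
from __future__ import annotations  # PEP 563: annotations stay unevaluated strings

from typing import TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
    # Visible only to the type checker; never executed at runtime.
    from decimal import Decimal as Tensor  # hypothetical stand-in for paddle.Tensor

class _ConfigSketch(TypedDict, total=False):
    # Referencing Tensor here is safe even though the import above is guarded:
    # the annotation is never evaluated when the module runs.
    learning_rate: float | Tensor
    weight_decay: float | None

cfg: _ConfigSketch = {'learning_rate': 0.1, 'weight_decay': None}
print(cfg)  # {'learning_rate': 0.1, 'weight_decay': None}
```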

I've already fixed this in PR #65277; @enkilee can merge the latest commit first, then push and see ~

@SigureMo (Member) commented:

In theory CI would merge it automatically; let me rerun it first and see.

@SigureMo (Member) commented:

@enkilee No change after the rerun; try merging.

@enkilee (Contributor, Author) commented Jun 20, 2024:

> @enkilee No change after the rerun; try merging.

Got it.

@megemini (Contributor) approved these changes:

LGTM ~ 🤟

@SigureMo (Member) approved these changes:

LGTMeow 🐾

@SigureMo SigureMo merged commit 3221677 into PaddlePaddle:develop Jun 21, 2024
33 checks passed
@SigureMo SigureMo added the HappyOpenSource 快乐开源活动issue与PR label Jun 21, 2024
@enkilee enkilee deleted the typing-a46-nadam branch June 25, 2024 02:10
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR