From 10c96105342f115babfd2031d6a4388a8216cb3a Mon Sep 17 00:00:00 2001 From: megemini Date: Mon, 8 Jul 2024 18:37:32 +0800 Subject: [PATCH] [Add] typing --- python/paddle/nn/utils/weight_norm_hook.py | 36 ++++++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/python/paddle/nn/utils/weight_norm_hook.py b/python/paddle/nn/utils/weight_norm_hook.py index 46e314210e0633..d263a1dbd40308 100644 --- a/python/paddle/nn/utils/weight_norm_hook.py +++ b/python/paddle/nn/utils/weight_norm_hook.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + import paddle from paddle import _C_ops @@ -18,10 +23,18 @@ from ...base.layer_helper import LayerHelper from ...framework import in_dynamic_mode +if TYPE_CHECKING: + from typing_extensions import Never + + from paddle import Tensor + from paddle.nn import Layer + __all__ = [] -def l2_norm(x, axis, epsilon=1e-12, name=None): +def l2_norm( + x: Tensor, axis: int, epsilon: float = 1e-12, name: str | None = None +) -> Tensor: if len(x.shape) == 1: axis = 0 @@ -46,7 +59,7 @@ def l2_norm(x, axis, epsilon=1e-12, name=None): return paddle.squeeze(norm, axis=[axis]) -def norm_except_dim(p, dim): +def norm_except_dim(p: Tensor, dim: int) -> Tensor: shape = p.shape ndims = len(shape) if dim == -1: @@ -65,7 +78,7 @@ def norm_except_dim(p, dim): return norm_except_dim(p_transposed, 0) -def _weight_norm(v, g, dim): +def _weight_norm(v: Tensor, g: Tensor, dim: int) -> Tensor: shape = v.shape ndims = len(shape) @@ -96,19 +109,22 @@ def _weight_norm(v, g, dim): class WeightNorm: - def __init__(self, name, dim): + name: str + dim: int + + def __init__(self, name: str, dim: int) -> None: if dim is None: dim = -1 self.name = name self.dim = dim - def compute_weight(self, layer): + def compute_weight(self, layer: Layer) -> Tensor: g = getattr(layer, self.name + '_g') v = getattr(layer, self.name + '_v') return _weight_norm(v, g, self.dim) @staticmethod - def apply(layer, name, dim): + def apply(layer: Layer, name: str, dim: int) -> WeightNorm: for k, hook in layer._forward_pre_hooks.items(): if isinstance(hook, WeightNorm) and hook.name == name: raise RuntimeError( @@ -145,7 +161,7 @@ def apply(layer, name, dim): layer.register_forward_pre_hook(fn) return fn - def remove(self, layer): + def remove(self, layer: Layer) -> None: w_var = self.compute_weight(layer) delattr(layer, self.name) del layer._parameters[self.name + '_g'] @@ -155,11 +171,11 @@ def remove(self, layer): with paddle.no_grad(): paddle.assign(w_var, w) - def __call__(self, layer, inputs): + def __call__(self, layer: Layer, inputs: Never) -> None: setattr(layer, self.name, self.compute_weight(layer)) -def weight_norm(layer, name='weight', dim=0): +def weight_norm(layer: Layer, name: str = 'weight', dim: int = 0) -> Layer: r""" Applies weight normalization to a parameter according to the following formula: @@ -205,7 +221,7 @@ def weight_norm(layer, name='weight', dim=0): return layer -def remove_weight_norm(layer, name='weight'): +def remove_weight_norm(layer: Layer, name: str = 'weight') -> Layer: """ remove weight normalization from layer.