From d23ea4ef8ebc637534c5abafca995be257f83751 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 12 Jan 2018 16:47:57 +0800 Subject: [PATCH 01/14] add gradient clip by norm --- python/paddle/v2/fluid/clip.py | 12 ++++++++++++ python/paddle/v2/fluid/layers/ops.py | 1 + 2 files changed, 13 insertions(+) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index b1fd1c2b65f10..eb75018d7798f 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -77,6 +77,18 @@ def create_operators(self, param, grad): return param, new_grad +class GradientClipByNorm(BaseGradientClipAttr): + def __init__(self, clip_norm): + self.clip_norm = clip_norm + + def process_context(self, context, p_g): + pass + + def create_operators(self, param, grad): + new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm) + return param, new_grad + + def append_gradient_clip_ops(param_grad): context = dict() create_op_callbacks = [] diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index d3a5b70785947..884e84011d960 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -16,6 +16,7 @@ 'elementwise_sub', 'elementwise_mul', 'clip', + 'clip_by_norm', 'sequence_softmax', ] + __activations__ From adc26dffa9dac81bd93c88d70f0ab66fcdcc81f0 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 15 Jan 2018 10:36:09 +0800 Subject: [PATCH 02/14] developing GradientClipByGlobalNorm --- python/paddle/v2/fluid/clip.py | 54 ++++++++++++++++++++++++---- python/paddle/v2/fluid/layers/ops.py | 20 ++++------- 2 files changed, 53 insertions(+), 21 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index eb75018d7798f..f0904e18ea346 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -1,5 +1,6 @@ import functools import layers +from framework import Variable from . 
import core __all__ = [ @@ -44,7 +45,7 @@ def error_clip_callback(block, context): class BaseGradientClipAttr(object): - def process_context(self, context, p_g): + def process_context(self, context, param, grad): raise NotImplementedError() def create_operators(self, param, grad): @@ -52,7 +53,7 @@ def create_operators(self, param, grad): class NullGradientClipAttr(BaseGradientClipAttr): - def process_context(self, context, p_g): + def process_context(self, context, param, grad): pass def create_operators(self, param, grad): @@ -69,7 +70,7 @@ def __init__(self, max, min=None): self.max = max self.min = min - def process_context(self, context, p_g): + def process_context(self, context, param, grad): pass def create_operators(self, param, grad): @@ -81,7 +82,7 @@ class GradientClipByNorm(BaseGradientClipAttr): def __init__(self, clip_norm): self.clip_norm = clip_norm - def process_context(self, context, p_g): + def process_context(self, context, param, grad): pass def create_operators(self, param, grad): @@ -89,6 +90,46 @@ def create_operators(self, param, grad): return param, new_grad +class GradientClipByGlobalNorm(BaseGradientClipAttr): + global_norm_var = None + clip_norm_var = None + ratio_var = None + + @classmethod + def init(cls, clip_norm): + cls.global_norm_var = layers.fill_constant( + shape=[1], dtype="float32", value=0.0) + cls.clip_norm_var = layers.fill_constant( + shape=[1], dtype="float32", value=clip_norm) + + def __init__(self): + if not (isinstance(self.__class__.global_norm_var, Variable) and + isinstance(self.__class__.clip_norm_var, Variable)): + raise ValueError( + "Class 'GradientClipByGlobalNorm' has not been properly initialized. Please call GradientClipByGlobalNorm.init() first." + ) + + def process_context(self, context, param, grad): + local_norm_var = layers.reduce_sum( + x=layers.pow(x=grad, factor=2), reduce_all=True) + layers.sums( + input=[local_norm_var, self.__class__.global_norm_var], + out=[self.__class__.global_norm_var]) + + def create_operators(self, param, grad): + if self.__class__.ratio_var is None: + self.__class__.global_norm_var = layers.sqrt( + x=self.__class__.global_norm_var) + self.__class__.ratio_var = layers.elementwise_div( + x=self.__class__.clip_norm_var, + y=layers.elementwise_max( + x=self.__class__.clip_norm_var, + y=self.__class__.global_norm_var)) + # elementwise_max is still missing + # ratio_var cannot be passed to scale_op yet. + # new_grad = layers.
+ + def append_gradient_clip_ops(param_grad): context = dict() create_op_callbacks = [] @@ -98,10 +139,9 @@ def append_gradient_clip_ops(param_grad): clip_attr = NullGradientClipAttr() if not isinstance(clip_attr, BaseGradientClipAttr): raise TypeError( - "clip attribute should be an instance of BaseGradientClippingAttr" - ) + "clip attribute should be an instance of BaseGradientClipAttr") - clip_attr.process_context(context=context, p_g=param_grad) + clip_attr.process_context(context=context, param=p, grad=g) create_op_callbacks.append( functools.partial( clip_attr.create_operators, param=p, grad=g)) diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 884e84011d960..021b87828f3ae 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -1,23 +1,15 @@ from ..registry import register_layer __activations__ = [ - 'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round' + 'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round', + 'pow' ] __all__ = [ - 'mean', - 'mul', - 'reshape', - 'scale', - 'transpose', - 'sigmoid_cross_entropy_with_logits', - 'elementwise_add', - 'elementwise_div', - 'elementwise_sub', - 'elementwise_mul', - 'clip', - 'clip_by_norm', - 'sequence_softmax', + 'mean', 'mul', 'reshape', 'scale', 'transpose', + 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div', + 'elementwise_sub', 'elementwise_mul', 'clip', 'clip_by_norm', + 'sequence_softmax', 'reduce_sum' ] + __activations__ for _OP in set(__all__): From f189ad74426cf0970bd05016d4a2827ea6c1ea00 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 17 Jan 2018 17:27:54 +0800 Subject: [PATCH 03/14] refine the defination of class GradientClipByGlobalNorm --- python/paddle/v2/fluid/clip.py | 47 +++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index f0904e18ea346..fcdd4c29e41f7 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -93,41 +93,48 @@ def create_operators(self, param, grad): class GradientClipByGlobalNorm(BaseGradientClipAttr): global_norm_var = None clip_norm_var = None - ratio_var = None + scale_var = None @classmethod def init(cls, clip_norm): + if not (isinstance(clip_norm, int) or isinstance(clip_norm, float)): + raise TypeError("The 'clip_norm' must be a value of int or float") + cls.global_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=0.0) cls.clip_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=clip_norm) - def __init__(self): - if not (isinstance(self.__class__.global_norm_var, Variable) and - isinstance(self.__class__.clip_norm_var, Variable)): + @classmethod + def check_init(cls): + if not (isinstance(cls.global_norm_var, Variable) and + isinstance(cls.clip_norm_var, Variable)): raise ValueError( - "Class 'GradientClipByGlobalNorm' has not been properly initialized. Please call GradientClipByGlobalNorm.init() first." - ) + "Class 'GradientClipByGlobalNorm' has not been properly initialized. 
\ + Please call GradientClipByGlobalNorm.init() first.") + + @classmethod + def process_context(cls, context, param, grad): + cls.check_init() - def process_context(self, context, param, grad): local_norm_var = layers.reduce_sum( x=layers.pow(x=grad, factor=2), reduce_all=True) layers.sums( - input=[local_norm_var, self.__class__.global_norm_var], - out=[self.__class__.global_norm_var]) + input=[local_norm_var, cls.global_norm_var], + out=[cls.global_norm_var]) - def create_operators(self, param, grad): - if self.__class__.ratio_var is None: - self.__class__.global_norm_var = layers.sqrt( - x=self.__class__.global_norm_var) - self.__class__.ratio_var = layers.elementwise_div( - x=self.__class__.clip_norm_var, + @classmethod + def create_operators(cls, param, grad): + cls.check_init() + + if cls.scale_var is None: + cls.global_norm_var = layers.sqrt(x=cls.global_norm_var) + cls.scale_var = layers.elementwise_div( + x=cls.clip_norm_var, y=layers.elementwise_max( - x=self.__class__.clip_norm_var, - y=self.__class__.global_norm_var)) - # 缺乏elementwise_max - # 没法将ratio_var送给scale_op。 - # new_grad = layers. + x=cls.clip_norm_var, y=cls.global_norm_var)) + new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var) + return param, new_grad def append_gradient_clip_ops(param_grad): From 4cb6e72b85fef0205a3d3ebfd136e11c009e39f6 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 17 Jan 2018 18:43:54 +0800 Subject: [PATCH 04/14] refine code details --- python/paddle/v2/fluid/clip.py | 18 +++++++-------- python/paddle/v2/fluid/framework.py | 2 +- python/paddle/v2/fluid/param_attr.py | 22 +++++++++---------- .../tests/book/test_recognize_digits_mlp.py | 18 +++++++-------- 4 files changed, 30 insertions(+), 30 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index d8240dc1557f0..f7917fc1423d8 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -1,16 +1,16 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import functools import layers from framework import Variable @@ -162,7 +162,7 @@ def append_gradient_clip_ops(param_grad): context = dict() create_op_callbacks = [] for p, g in param_grad: - clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr()) + clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr()) if clip_attr is None: clip_attr = NullGradientClipAttr() if not isinstance(clip_attr, BaseGradientClipAttr): diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 8042febfed7ed..9128a0eebeb25 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -946,7 +946,7 @@ def __init__(self, block, shape, dtype, **kwargs): self.regularizer = kwargs.get('regularizer', None) - self.clip_attr = kwargs.get('clip_attr', None) + self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None) # program is a global instance. diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py index 3af0190590e77..8c8de0d104664 100644 --- a/python/paddle/v2/fluid/param_attr.py +++ b/python/paddle/v2/fluid/param_attr.py @@ -1,16 +1,16 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from initializer import Initializer, Xavier, Constant from regularizer import WeightDecayRegularizer @@ -24,13 +24,13 @@ def __init__(self, learning_rate=1.0, regularizer=None, trainable=True, - clip=None): + gradient_clip=None): self.name = name self.initializer = initializer self.learning_rate = learning_rate self.regularizer = regularizer self.trainable = trainable - self.clip = clip + self.gradient_clip = gradient_clip def set_default_initializer(self, initializer): if initializer is None: @@ -76,7 +76,7 @@ def to_kwargs(self, with_initializer=False): }, 'regularizer': self.regularizer, 'trainable': self.trainable, - 'clip_attr': self.clip + 'gradient_clip_attr': self.gradient_clip } if with_initializer: kwargs['initializer'] = self.initializer diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index 02da2fcc8544d..e614e5e3f134b 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -1,16 +1,16 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. 
-#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import print_function import numpy as np import paddle.v2 as paddle @@ -26,7 +26,7 @@ act='relu', param_attr=fluid.ParamAttr( regularizer=regularizer, - clip=fluid.clip.ClipByValue(10))) + gradient_clip=fluid.clip.ClipByValue(10))) hidden2 = fluid.layers.fc(input=hidden1, size=64, From 6ebfade465be5526939b52b0d251486298c4c734 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 17 Jan 2018 19:14:05 +0800 Subject: [PATCH 05/14] fix copyright information --- paddle/gserver/tests/sequence_recurrent_group.py | 13 +++++++++++++ .../paddle/v2/fluid/tests/test_edit_distance_op.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/paddle/gserver/tests/sequence_recurrent_group.py b/paddle/gserver/tests/sequence_recurrent_group.py index a1d54542e3bc4..1343f2956f397 100644 --- a/paddle/gserver/tests/sequence_recurrent_group.py +++ b/paddle/gserver/tests/sequence_recurrent_group.py @@ -1,3 +1,16 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. #!/usr/bin/env python # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved # diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py index 38e87728b387b..cf118df634bb8 100644 --- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -1,3 +1,16 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
import unittest import numpy as np from op_test import OpTest From 1dac173b518faeb8f31c321a61fa287b8de4246e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 17 Jan 2018 20:15:03 +0800 Subject: [PATCH 06/14] add API for clip_by_global_norm --- python/paddle/v2/fluid/clip.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index f7917fc1423d8..d1e6987e01885 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -13,7 +13,7 @@ # limitations under the License. import functools import layers -from framework import Variable +import framework from . import core __all__ = [ @@ -128,8 +128,8 @@ def init(cls, clip_norm): @classmethod def check_init(cls): - if not (isinstance(cls.global_norm_var, Variable) and - isinstance(cls.clip_norm_var, Variable)): + if not (isinstance(cls.global_norm_var, framework.Variable) and + isinstance(cls.clip_norm_var, framework.Variable)): raise ValueError( "Class 'GradientClipByGlobalNorm' has not been properly initialized. \ Please call GradientClipByGlobalNorm.init() first.") @@ -158,6 +158,23 @@ def create_operators(cls, param, grad): return param, new_grad +def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None): + if program is None: + program = framework.default_main_program() + if param_list is None: + param_list = program.block(0).all_parameters() + if all(isinstance(elem, basestring) for elem in param_list): + param_list = [program.block(0).var(elem) for elem in param_list] + if not all(isinstance(elem, framework.Parameter) for elem in param_list): + raise TypeError( + "'param_list' should be a list of Parameter or basestring(parameter's name)." + ) + + GradientClipByGlobalNorm.init(clip_norm) + for param in param_list: + param.gradient_clip_attr = GradientClipByGlobalNorm() + + def append_gradient_clip_ops(param_grad): context = dict() create_op_callbacks = [] From 958d07bee3343288f9813693b5a85150a5131cdd Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 17 Jan 2018 20:21:05 +0800 Subject: [PATCH 07/14] fix a error --- python/paddle/v2/fluid/framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 9128a0eebeb25..91fdb5fa7e4ec 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -777,7 +777,7 @@ def copy_param_info_from(self, other): trainable=p.trainable, optimize_attr=p.optimize_attr, regularizer=p.regularizer, - clip_attr=p.clip_attr, + gradient_clip_attr=p.gradient_clip_attr, error_clip=p.error_clip, name=v.name) self.vars[new_p.name] = new_p From a247972ddad05490a7b72911521bff0b48cf2d1c Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 17 Jan 2018 20:31:05 +0800 Subject: [PATCH 08/14] fix a error --- python/paddle/v2/fluid/clip.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index d1e6987e01885..7a36df0dabbca 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -134,8 +134,8 @@ def check_init(cls): "Class 'GradientClipByGlobalNorm' has not been properly initialized. 
\ Please call GradientClipByGlobalNorm.init() first.") - @classmethod - def process_context(cls, context, param, grad): + def process_context(self, context, param, grad): + cls = self.__class__ cls.check_init() local_norm_var = layers.reduce_sum( @@ -144,8 +144,8 @@ def process_context(cls, context, param, grad): input=[local_norm_var, cls.global_norm_var], out=[cls.global_norm_var]) - @classmethod - def create_operators(cls, param, grad): + def create_operators(self, param, grad): + cls = self.__class__ cls.check_init() if cls.scale_var is None: From 773f2f735c235afcc6ea40ddc2af23fe7a69a2e9 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 18 Jan 2018 21:06:51 +0800 Subject: [PATCH 09/14] fix errors --- python/paddle/v2/fluid/clip.py | 5 +++-- python/paddle/v2/fluid/layers/ops.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index 7a36df0dabbca..d4f025a4af60d 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -138,8 +138,7 @@ def process_context(self, context, param, grad): cls = self.__class__ cls.check_init() - local_norm_var = layers.reduce_sum( - x=layers.pow(x=grad, factor=2), reduce_all=True) + local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0)) layers.sums( input=[local_norm_var, cls.global_norm_var], out=[cls.global_norm_var]) @@ -154,6 +153,8 @@ def create_operators(self, param, grad): x=cls.clip_norm_var, y=layers.elementwise_max( x=cls.clip_norm_var, y=cls.global_norm_var)) + assert cls.scale_var.shape == (1L, ) + new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var) return param, new_grad diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index dd3197fc0029b..a2055c5d7b844 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -48,7 +48,7 @@ 'mean', 'mul', 'reshape', 'scale', 'transpose', 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min', - 'clip', 'clip_by_norm', 'sequence_softmax', 'reduce_sum' + 'clip', 'clip_by_norm', 'sequence_softmax' ] + __activations__ for _OP in set(__all__): From 42b0748ab4f797902cadad4b5278a4cb9fdea9bd Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 19 Jan 2018 15:12:39 +0800 Subject: [PATCH 10/14] add unittest --- python/paddle/v2/fluid/clip.py | 14 +++- .../{test_clip.py => test_error_clip.py} | 0 .../v2/fluid/tests/test_gradient_clip.py | 82 +++++++++++++++++++ 3 files changed, 93 insertions(+), 3 deletions(-) rename python/paddle/v2/fluid/tests/{test_clip.py => test_error_clip.py} (100%) create mode 100644 python/paddle/v2/fluid/tests/test_gradient_clip.py diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index d4f025a4af60d..f6ff83924f251 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -113,6 +113,7 @@ def create_operators(self, param, grad): class GradientClipByGlobalNorm(BaseGradientClipAttr): global_norm_var = None + local_norm_var = None clip_norm_var = None scale_var = None @@ -123,12 +124,18 @@ def init(cls, clip_norm): cls.global_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=0.0) + cls.local_norm_var = framework.default_main_program().current_block( + ).create_var( + name=framework.unique_name("local_norm"), + dtype="float32", + persistable=False) cls.clip_norm_var = layers.fill_constant( shape=[1], dtype="float32", 
value=clip_norm) @classmethod def check_init(cls): if not (isinstance(cls.global_norm_var, framework.Variable) and + isinstance(cls.local_norm_var, framework.Variable) and isinstance(cls.clip_norm_var, framework.Variable)): raise ValueError( "Class 'GradientClipByGlobalNorm' has not been properly initialized. \ @@ -138,9 +145,10 @@ def process_context(self, context, param, grad): cls = self.__class__ cls.check_init() - local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0)) + cls.local_norm_var = layers.reduce_sum( + input=layers.pow(x=grad, factor=2.0)) layers.sums( - input=[local_norm_var, cls.global_norm_var], + input=[cls.local_norm_var, cls.global_norm_var], out=[cls.global_norm_var]) def create_operators(self, param, grad): @@ -148,7 +156,7 @@ def create_operators(self, param, grad): cls.check_init() if cls.scale_var is None: - cls.global_norm_var = layers.sqrt(x=cls.global_norm_var) + layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var) cls.scale_var = layers.elementwise_div( x=cls.clip_norm_var, y=layers.elementwise_max( diff --git a/python/paddle/v2/fluid/tests/test_clip.py b/python/paddle/v2/fluid/tests/test_error_clip.py similarity index 100% rename from python/paddle/v2/fluid/tests/test_clip.py rename to python/paddle/v2/fluid/tests/test_error_clip.py diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py new file mode 100644 index 0000000000000..4fb7f0b2cb3f6 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -0,0 +1,82 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + + +def _get_global_param_norm_(params_grads): + res = fluid.layers.fill_constant(shape=[1], dtype="float32", value=0.0) + for _, grad in params_grads: + norm_var = fluid.layers.reduce_sum( + input=fluid.layers.pow(x=grad, factor=2.0)) + fluid.layers.sums(input=[norm_var, res], out=[res]) + fluid.layers.sqrt(x=res, out=res) + return res + + +BATCH_SIZE = 128 +CLIP = 0.5 +prog = fluid.framework.Program() + +with fluid.program_guard(main_program=prog): + image = fluid.layers.data(name='x', shape=[784], dtype='float32') + + hidden1 = fluid.layers.fc(input=image, size=128, act='relu') + hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu') + predict = fluid.layers.fc(input=hidden2, size=10, act='softmax') + + label = fluid.layers.data(name='y', shape=[1], dtype='int64') + + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + +prog_clip = prog.clone() + +avg_cost_clip = prog_clip.block(0).var(avg_cost.name) + +p_g = fluid.backward.append_backward(loss=avg_cost) +p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) + +with fluid.program_guard(main_program=prog): + gloabl_norm = _get_global_param_norm_(p_g) + +with fluid.program_guard(main_program=prog_clip): + fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) + p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) + gloabl_norm_clip = _get_global_param_norm_(p_g_clip) + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=8192), + batch_size=BATCH_SIZE) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) +feeder = fluid.DataFeeder(feed_list=[image, label], place=place) +exe.run(fluid.default_startup_program()) + +count = 0 +for data in train_reader(): + count += 1 + if count > 5: + break + out, = exe.run(prog, feed=feeder.feed(data), fetch_list=[gloabl_norm]) + out_clip, = exe.run(prog_clip, + feed=feeder.feed(data), + fetch_list=[gloabl_norm_clip]) + + if not np.allclose(out_clip, np.minimum(out, np.array([CLIP]))): + exit(1) +exit(0) From 408a6b8bb2af4f8f075680bb361daad329ad6eca Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 19 Jan 2018 15:17:35 +0800 Subject: [PATCH 11/14] tiny fix --- python/paddle/v2/fluid/clip.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index f6ff83924f251..9800ad7c5d024 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -124,11 +124,11 @@ def init(cls, clip_norm): cls.global_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=0.0) - cls.local_norm_var = framework.default_main_program().current_block( - ).create_var( - name=framework.unique_name("local_norm"), - dtype="float32", - persistable=False) + cls.local_norm_var = framework.default_main_program().block( + 0).create_var( + name=framework.unique_name("local_norm"), + dtype="float32", + persistable=False) cls.clip_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=clip_norm) From 538f1ad28f766c0e47ef4eef2ec59e187ba30f8e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 19 Jan 2018 15:24:14 +0800 Subject: [PATCH 12/14] tiny fix --- python/paddle/v2/fluid/clip.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index 9800ad7c5d024..d97cd9ecc936e 100644 --- a/python/paddle/v2/fluid/clip.py +++ 
b/python/paddle/v2/fluid/clip.py @@ -124,11 +124,7 @@ def init(cls, clip_norm): cls.global_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=0.0) - cls.local_norm_var = framework.default_main_program().block( - 0).create_var( - name=framework.unique_name("local_norm"), - dtype="float32", - persistable=False) + cls.local_norm_var = layers.create_tensor(dtype="float32") cls.clip_norm_var = layers.fill_constant( shape=[1], dtype="float32", value=clip_norm) From 19c554f9e4ef5c96e47f65efd44e2524417e38d7 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 19 Jan 2018 19:19:35 +0800 Subject: [PATCH 13/14] update --- python/paddle/v2/fluid/clip.py | 82 +++++++++---------- .../v2/fluid/tests/test_gradient_clip.py | 44 +++++----- 2 files changed, 59 insertions(+), 67 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index d97cd9ecc936e..fb0907c9f4a74 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -112,58 +112,52 @@ def create_operators(self, param, grad): class GradientClipByGlobalNorm(BaseGradientClipAttr): - global_norm_var = None - local_norm_var = None - clip_norm_var = None - scale_var = None - - @classmethod - def init(cls, clip_norm): - if not (isinstance(clip_norm, int) or isinstance(clip_norm, float)): - raise TypeError("The 'clip_norm' must be a value of int or float") - - cls.global_norm_var = layers.fill_constant( - shape=[1], dtype="float32", value=0.0) - cls.local_norm_var = layers.create_tensor(dtype="float32") - cls.clip_norm_var = layers.fill_constant( - shape=[1], dtype="float32", value=clip_norm) - - @classmethod - def check_init(cls): - if not (isinstance(cls.global_norm_var, framework.Variable) and - isinstance(cls.local_norm_var, framework.Variable) and - isinstance(cls.clip_norm_var, framework.Variable)): - raise ValueError( - "Class 'GradientClipByGlobalNorm' has not been properly initialized. 
\ - Please call GradientClipByGlobalNorm.init() first.") + def __init__(self, clip_norm, group_name="default_group"): + if not isinstance(group_name, basestring): + raise TypeError("'group_name' must be a basestring.") + + self.clip_norm = clip_norm + self.group_name = group_name def process_context(self, context, param, grad): - cls = self.__class__ - cls.check_init() + if self.group_name not in context: + context[self.group_name] = [] + context[self.group_name + "_clip_value"] = self.clip_norm + context[self.group_name + "_clip"] = layers.fill_constant( + shape=[1], dtype="float32", value=self.clip_norm) + else: + if not self.clip_norm == context[self.group_name + "_clip_value"]: + raise ValueError( + "All parameters' 'clip_norm' of a same group should be the same" + ) - cls.local_norm_var = layers.reduce_sum( - input=layers.pow(x=grad, factor=2.0)) - layers.sums( - input=[cls.local_norm_var, cls.global_norm_var], - out=[cls.global_norm_var]) + local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0)) + context[self.group_name].append(local_norm_var) - def create_operators(self, param, grad): - cls = self.__class__ - cls.check_init() + self.context = context - if cls.scale_var is None: - layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var) - cls.scale_var = layers.elementwise_div( - x=cls.clip_norm_var, + def create_operators(self, param, grad): + group_scale_name = self.group_name + "_scale" + if group_scale_name not in self.context: + group_norm_var = layers.sums(input=self.context[self.group_name]) + layers.sqrt(x=group_norm_var, out=group_norm_var) + clip_var = self.context[self.group_name + "_clip"] + group_scale_var = layers.elementwise_div( + x=clip_var, y=layers.elementwise_max( - x=cls.clip_norm_var, y=cls.global_norm_var)) - assert cls.scale_var.shape == (1L, ) + x=clip_var, y=group_norm_var)) + assert group_scale_var.shape == (1L, ) + self.context[group_scale_name] = group_scale_var - new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var) + new_grad = layers.elementwise_mul( + x=grad, y=self.context[group_scale_name]) return param, new_grad -def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None): +def gradient_clip_by_global_norm(clip_norm, + param_list=None, + group_name="default_group", + program=None): if program is None: program = framework.default_main_program() if param_list is None: @@ -175,9 +169,9 @@ def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None): "'param_list' should be a list of Parameter or basestring(parameter's name)." 
) - GradientClipByGlobalNorm.init(clip_norm) for param in param_list: - param.gradient_clip_attr = GradientClipByGlobalNorm() + param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm, + group_name) def append_gradient_clip_ops(param_grad): diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py index 4fb7f0b2cb3f6..75c5fd98925b6 100644 --- a/python/paddle/v2/fluid/tests/test_gradient_clip.py +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -15,21 +15,10 @@ import paddle.v2 as paddle import paddle.v2.fluid as fluid - -def _get_global_param_norm_(params_grads): - res = fluid.layers.fill_constant(shape=[1], dtype="float32", value=0.0) - for _, grad in params_grads: - norm_var = fluid.layers.reduce_sum( - input=fluid.layers.pow(x=grad, factor=2.0)) - fluid.layers.sums(input=[norm_var, res], out=[res]) - fluid.layers.sqrt(x=res, out=res) - return res - - BATCH_SIZE = 128 -CLIP = 0.5 -prog = fluid.framework.Program() +CLIP = 1 +prog = fluid.framework.Program() with fluid.program_guard(main_program=prog): image = fluid.layers.data(name='x', shape=[784], dtype='float32') @@ -49,13 +38,12 @@ def _get_global_param_norm_(params_grads): p_g = fluid.backward.append_backward(loss=avg_cost) p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) -with fluid.program_guard(main_program=prog): - gloabl_norm = _get_global_param_norm_(p_g) - with fluid.program_guard(main_program=prog_clip): fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) - gloabl_norm_clip = _get_global_param_norm_(p_g_clip) + +grad_list = [elem[1] for elem in p_g] +grad_clip_list = [elem[1] for elem in p_g_clip] train_reader = paddle.batch( paddle.reader.shuffle( @@ -72,11 +60,21 @@ def _get_global_param_norm_(params_grads): count += 1 if count > 5: break - out, = exe.run(prog, feed=feeder.feed(data), fetch_list=[gloabl_norm]) - out_clip, = exe.run(prog_clip, - feed=feeder.feed(data), - fetch_list=[gloabl_norm_clip]) - - if not np.allclose(out_clip, np.minimum(out, np.array([CLIP]))): + out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list) + out_clip = exe.run(prog_clip, + feed=feeder.feed(data), + fetch_list=grad_clip_list) + global_norm = 0 + for v in out[1:]: + global_norm += np.sum(np.power(v, 2)) + global_norm = np.sqrt(global_norm) + + global_norm_clip = 0 + for v in out_clip[1:]: + global_norm_clip += np.sum(np.power(v, 2)) + global_norm_clip = np.sqrt(global_norm_clip) + + if not np.isclose( + a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3): exit(1) exit(0) From e8adcaf27855e452dea5b2deaddb830363cd3964 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 22 Jan 2018 12:50:45 +0800 Subject: [PATCH 14/14] update --- python/paddle/v2/fluid/clip.py | 1 + python/paddle/v2/fluid/layers/ops.py | 1 + python/paddle/v2/fluid/param_attr.py | 1 + python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py | 1 + python/paddle/v2/fluid/tests/test_gradient_clip.py | 1 + 5 files changed, 5 insertions(+) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index 777a39e105c0e..386df9823de91 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import functools import layers import framework diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 7e52dc4c34f29..d296076162669 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from ..registry import register_layer __activations__ = [ diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py index 17fcb262efb3d..dcca8b6c547d1 100644 --- a/python/paddle/v2/fluid/param_attr.py +++ b/python/paddle/v2/fluid/param_attr.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from initializer import Initializer, Xavier, Constant from regularizer import WeightDecayRegularizer diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index 2fde3707da35a..8776a65bf804e 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from __future__ import print_function import numpy as np import paddle.v2 as paddle diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py index 75c5fd98925b6..4e6e6a1ef6961 100644 --- a/python/paddle/v2/fluid/tests/test_gradient_clip.py +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np import paddle.v2 as paddle import paddle.v2.fluid as fluid
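The series leaves fluid with two user-facing ways to request gradient clipping: a per-parameter attribute passed through ParamAttr(gradient_clip=...), as in test_recognize_digits_mlp.py above, and the program-wide helper gradient_clip_by_global_norm() followed by append_gradient_clip_ops(), as in test_gradient_clip.py. The sketch below shows how the global-norm path is wired into a training program; the small network and the clip_norm value are illustrative assumptions, not part of the patches.

    import paddle.v2.fluid as fluid

    # Build a small classifier on the default main program (illustrative only).
    image = fluid.layers.data(name='x', shape=[784], dtype='float32')
    label = fluid.layers.data(name='y', shape=[1], dtype='int64')
    # A per-parameter clip could instead be attached here, e.g.
    # param_attr=fluid.ParamAttr(gradient_clip=fluid.clip.GradientClipByNorm(5.0)).
    hidden = fluid.layers.fc(input=image, size=128, act='relu')
    predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
    cost = fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # Append the backward pass, mark every parameter of the default program for
    # global-norm clipping, then insert the clip ops over the (param, grad) pairs.
    p_g = fluid.backward.append_backward(loss=avg_cost)
    fluid.clip.gradient_clip_by_global_norm(clip_norm=1.0)
    p_g = fluid.clip.append_gradient_clip_ops(p_g)
    # p_g now holds (param, clipped_grad) pairs; the gradients are rescaled so
    # their global L2 norm does not exceed clip_norm, as checked in
    # test_gradient_clip.py.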