diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
index e30a84e12826a..8e88a213b5454 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
@@ -236,6 +236,10 @@ def _get_input_varlist(self, program):
                 ret_list.append(var)
         return ret_list
 
+    def _set_auxiliary_var(self, key, val):
+        super()._set_auxiliary_var(key, val)
+        self.inner_opt._set_auxiliary_var(key, val)
+
     def minimize(
         self,
         loss,
diff --git a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py
index 87085a322c303..9a7660ebd7dc1 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py
@@ -25,6 +25,10 @@ def __init__(self, optimizer):
         self.meta_optimizers_white_list = []
         self.meta_optimizers_black_list = []
 
+    def _set_auxiliary_var(self, key, val):
+        super()._set_auxiliary_var(key, val)
+        self.inner_opt._set_auxiliary_var(key, val)
+
     def _set_basic_info(
         self, loss, role_maker, user_defined_optimizer, user_defined_strategy
     ):
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 00ec12a523f91..639bdf79ac9aa 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -203,6 +203,10 @@ def __init__(
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self._update_opt_status()
 
+    def _set_auxiliary_var(self, key, val):
+        super()._set_auxiliary_var(key, val)
+        self._optim._set_auxiliary_var(key, val)
+
     @paddle.autograd.no_grad()
     def _sync_params_and_buffers(self):
         """
diff --git a/python/paddle/incubate/optimizer/lookahead.py b/python/paddle/incubate/optimizer/lookahead.py
index b1ad5f3ecb0b5..bfa08c40556be 100644
--- a/python/paddle/incubate/optimizer/lookahead.py
+++ b/python/paddle/incubate/optimizer/lookahead.py
@@ -144,6 +144,10 @@ def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None):
        self._global_step_var = None
        self._k_var = None

+    def _set_auxiliary_var(self, key, val):
+        super()._set_auxiliary_var(key, val)
+        self.inner_optimizer._set_auxiliary_var(key, val)
+
    @framework.dygraph_only
    @imperative_base.no_grad
    def step(self):
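All four hunks apply the same fix: each of these classes wraps an inner optimizer, and before this change a `_set_auxiliary_var` call stopped at the wrapper, so auxiliary state never reached the optimizer that actually applies the parameter update. The override records the value on the wrapper via `super()` and then forwards it to the wrapped instance. Below is a minimal, framework-free sketch of that delegation pattern, assuming the base class stores auxiliary variables in a plain dict (as `paddle.optimizer.Optimizer` does with `_auxiliary_vars`); the class names `BaseOptimizer` and `WrapperOptimizer` are hypothetical stand-ins for the Paddle classes touched above.

```python
class BaseOptimizer:
    """Stand-in for the real base optimizer: keeps auxiliary state in a dict."""

    def __init__(self):
        self._auxiliary_vars = {}

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)


class WrapperOptimizer(BaseOptimizer):
    """Stand-in for a meta/sharding/lookahead optimizer that wraps another one."""

    def __init__(self, inner_opt):
        super().__init__()
        self.inner_opt = inner_opt

    def _set_auxiliary_var(self, key, val):
        # Record the value on the wrapper itself ...
        super()._set_auxiliary_var(key, val)
        # ... then forward it, so the optimizer that actually performs the
        # parameter update observes the same auxiliary state.
        self.inner_opt._set_auxiliary_var(key, val)


inner = BaseOptimizer()
outer = WrapperOptimizer(inner)
outer._set_auxiliary_var("found_inf", True)
assert inner._get_auxiliary_var("found_inf") is True
```

Calling `super()._set_auxiliary_var` before forwarding keeps the wrapper's own copy of the state consistent, so lookups on either the wrapper or the inner optimizer agree; the key name `"found_inf"` in the usage example is illustrative only.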