diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 721d10d76b128..1f644147a209b 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -252,14 +252,26 @@ def check_models(models): ) +def _is_valid_optimizer(optimizer): + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + ) + + return isinstance( + optimizer, + ( + paddle.optimizer.Optimizer, + paddle.fluid.optimizer.Optimizer, + DygraphShardingOptimizer, + ), + ) + + def check_optimizers(optimizers): for optimizer in optimizers: - if not isinstance( - optimizer, - (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer), - ): + if not _is_valid_optimizer(optimizer): raise RuntimeError( - "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format( + "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format( type(optimizer) ) ) @@ -477,6 +489,20 @@ def __call__(self, state_dict): state_dict[key] = param_applied +def _set_multi_precision(optimizer, multi_precision): + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + ) + + optimizer = ( + optimizer._inner_optimizer + if isinstance(optimizer, DygraphShardingOptimizer) + else optimizer + ) + if hasattr(optimizer, "_multi_precision"): + optimizer._multi_precision = multi_precision + + @dygraph_only def amp_decorate( models, @@ -582,10 +608,7 @@ def amp_decorate( if optimizers is not None: # check optimizers optimizers_is_list = False - if isinstance( - optimizers, - (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer), - ): + if _is_valid_optimizer(optimizers): optimizers_is_list = False optimizers = [optimizers] check_optimizers(optimizers) @@ -596,13 +619,10 @@ def amp_decorate( raise TypeError( "optimizers must be either a single optimizer or a list of optimizers." ) - # supprot master_weight - for idx_opt in range(len(optimizers)): - if hasattr(optimizers[idx_opt], '_multi_precision'): - if master_weight is False: - optimizers[idx_opt]._multi_precision = False - else: - optimizers[idx_opt]._multi_precision = True + # support master_weight + use_multi_precision = not (master_weight is False) + for opt in optimizers: + _set_multi_precision(opt, use_multi_precision) if save_dtype is not None: if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):