[AutoParallel] fix fp16 for subblock (PaddlePaddle#47189)
* [AutoParallel] fix fp16 for subblock

* fix engine

* fix comment
zhaoyinglia authored and zhaoyingli committed Oct 20, 2022
1 parent 9ed1454 commit 9cda2d6
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -240,8 +240,8 @@ def _infer_item_spec(item, name, batch_size, specs):
             else:
                 specs.append(spec.batch(batch_size))
         elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
-            _adjust_item_spec(num_shards, spec)
             spec = InputSpec.from_tensor(item, name)
+            _adjust_item_spec(num_shards, spec)
             if batch_size is None:
                 specs.append(spec)
             else:
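The reorder matters because InputSpec.from_tensor rebinds spec to a brand-new object, so the old code adjusted a spec that was immediately discarded. A minimal sketch of the rebinding pitfall, using hypothetical stand-ins (Spec, adjust, from_tensor) rather than the real engine helpers:

    # Hypothetical stand-ins illustrating the ordering bug fixed above.
    class Spec:
        def __init__(self, shape):
            self.shape = shape

    def adjust(spec, num_shards):
        # Shards the batch dimension in place, like _adjust_item_spec.
        spec.shape[0] //= num_shards

    def from_tensor(shape):
        # Returns a brand-new spec object, like InputSpec.from_tensor.
        return Spec(list(shape))

    # Old order: adjust, then rebind -- the adjustment is lost.
    spec = Spec([8, 16])
    adjust(spec, 2)
    spec = from_tensor([8, 16])
    assert spec.shape == [8, 16]   # sharding silently discarded

    # New order: rebind, then adjust -- the sharding sticks.
    spec = from_tensor([8, 16])
    adjust(spec, 2)
    assert spec.shape == [4, 16]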
@@ -1508,10 +1508,10 @@ def load(self, path, strict=True, load_optimizer=True):
             strict (bool, optional): Whether to skip the loading of mismatch
                 parameter or raise an error when mismatch happens (not found
                 the parameter in file storing model states of or receives a
-                mismatch shape). Default: False.
+                mismatch shape). Default: True.
             load_optimizer (bool, optional): If True, the stored optimizer
                 states is restored. Otherwise, the optimizer states is initialized
-                from scratch. Default: False.
+                from scratch. Default: True.
         Returns:
             None
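This hunk only brings the docstring in line with the defaults already in the signature (strict=True, load_optimizer=True). A hedged usage sketch, assuming `engine` is an already-constructed auto-parallel Engine and "./ckpt/model" is a hypothetical checkpoint path:

    # Defaults per the corrected docstring: mismatches raise (strict=True)
    # and optimizer states are restored (load_optimizer=True).
    engine.load("./ckpt/model")

    # Tolerate mismatched parameters and reinitialize the optimizer instead.
    engine.load("./ckpt/model", strict=False, load_optimizer=False)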
3 changes: 2 additions & 1 deletion python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -181,7 +181,8 @@ def set_var_to_fp16(self, var_name, block):
         try:
             var = block.var(var_name)
         except ValueError as e:
-            var = self.program.global_block().var(var_name)
+            var = block._var_recursive(var_name)
+            # var = self.program.global_block().var(var_name)

         # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
         # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
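Block.var only searches the current block, so for a variable referenced inside a sub-block (e.g. a while block) the old fallback to the global block could miss definitions in intermediate parent blocks; Block._var_recursive walks the parent chain instead. A minimal sketch of that lookup strategy, with a hypothetical Block class rather than Paddle's implementation:

    # Hypothetical Block class sketching recursive variable lookup.
    class Block:
        def __init__(self, parent=None):
            self.vars = {}
            self.parent = parent

        def var(self, name):
            # Current-block-only lookup, like Paddle's Block.var.
            if name not in self.vars:
                raise ValueError(f"var {name} not in this block")
            return self.vars[name]

        def var_recursive(self, name):
            # Walk up the parent chain, like Block._var_recursive.
            block = self
            while block is not None:
                if name in block.vars:
                    return block.vars[name]
                block = block.parent
            raise ValueError(f"var {name} not found in any ancestor block")

    main = Block()               # global block
    sub = Block(parent=main)     # e.g. a while sub-block
    main.vars["w"] = "fp32 param"

    # var() on the sub-block raises, but the recursive lookup succeeds,
    # which is what the fix above relies on for fp16 casting in sub-blocks.
    assert sub.var_recursive("w") == "fp32 param"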