Skip to content

Commit

Permalink
support dtype auto convertion in deit load function (PaddlePaddle#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Jan 11, 2023
1 parent 970d8c3 commit 4c6da7d
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions plsc/models/deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def load_pretrained(self, path, rank=0, finetune=False):

state_dict = self.state_dict()
param_state_dict = paddle.load(path + ".pdparams")

# for FP16 saving pretrained weight
for key, value in param_state_dict.items():
if key in param_state_dict and key in state_dict and param_state_dict[
key].dtype != state_dict[key].dtype:
param_state_dict[key] = param_state_dict[key].astype(
state_dict[key].dtype)

if not finetune:
self.set_dict(param_state_dict)
return
Expand Down

0 comments on commit 4c6da7d

Please sign in to comment.