From 4c6da7d29c0e982928d51b63ad929ada1a7ac3e8 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Wed, 11 Jan 2023 20:24:07 +0800 Subject: [PATCH] support dtype auto convertion in deit load function (#180) --- plsc/models/deit.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/plsc/models/deit.py b/plsc/models/deit.py index 66bbf57ae6fba..a9bdf5c6079dd 100644 --- a/plsc/models/deit.py +++ b/plsc/models/deit.py @@ -141,6 +141,14 @@ def load_pretrained(self, path, rank=0, finetune=False): state_dict = self.state_dict() param_state_dict = paddle.load(path + ".pdparams") + + # for FP16 saving pretrained weight + for key, value in param_state_dict.items(): + if key in param_state_dict and key in state_dict and param_state_dict[ + key].dtype != state_dict[key].dtype: + param_state_dict[key] = param_state_dict[key].astype( + state_dict[key].dtype) + if not finetune: self.set_dict(param_state_dict) return