From ebed000441068adf3e6d613a485ea672a2239e5a Mon Sep 17 00:00:00 2001
From: Dennj
Date: Mon, 27 Nov 2023 21:47:49 +0000
Subject: [PATCH] Added safetensors support in from_pretrained()

https://github.com/aws-neuron/transformers-neuronx/issues/60
https://github.com/aws-neuron/aws-neuron-sdk/issues/786
---
Review notes (free text in this area, up to the diff below, is ignored when
the patch is applied):

* What the patch does: from_pretrained() now probes, in order,
  <path>/model.safetensors (file), <path>/pytorch_model.bin (directory),
  <path>/pytorch_model.bin (file), and raises FileNotFoundError when none
  of them exists. Previously a missing pytorch_model.bin fell through to
  torch.load().
* _safeload() reads every tensor of the safetensors file into a plain dict
  before handing it to load_state_dict_low_memory().
* Only the exact file name model.safetensors is probed; other layouts
  (for example sharded model-*.safetensors files) are not looked up.
* safe_open is imported at module level, so safetensors becomes a hard
  import-time dependency of module.py.
* NOTE(review): the torch.load() fallback is unchanged context; it relies
  on pickle, so it should only be pointed at trusted checkpoints.
* NOTE(review): line breaks, indentation and blank context lines in this
  file were reconstructed from the hunk line counts; verify the patch
  against the target tree (e.g. a dry-run apply) before using it.

 src/transformers_neuronx/module.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/transformers_neuronx/module.py b/src/transformers_neuronx/module.py
index faf0f1e5..7839f814 100644
--- a/src/transformers_neuronx/module.py
+++ b/src/transformers_neuronx/module.py
@@ -18,6 +18,7 @@
 import warnings
 
 import torch
+from safetensors import safe_open
 from torch.nn.parameter import UninitializedParameter
 from transformers import AutoConfig
 
@@ -137,16 +138,32 @@ class LowMemoryLazyLinear(torch.nn.LazyLinear, LowMemoryModule): ...
 
 
 class PretrainedModel(LowMemoryModule):
+    @staticmethod
+    def _safeload(state_dict_path):
+        state_dict = {}
+        with safe_open(state_dict_path, framework="pt") as f:
+            for k in f.keys():
+                state_dict[k] = f.get_tensor(k)
+        return state_dict
+
     @classmethod
     def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
         config = AutoConfig.from_pretrained(pretrained_model_path)
         model = cls(config, *model_args, **kwargs)
         state_dict_path = os.path.join(pretrained_model_path, 'pytorch_model.bin')
-        if os.path.isdir(state_dict_path):
+        state_dict_safetensor_path = os.path.join(pretrained_model_path, 'model.safetensors')
+
+        if os.path.isfile(state_dict_safetensor_path):
+            state_dict = PretrainedModel._safeload(state_dict_safetensor_path)
+            model.load_state_dict_low_memory(state_dict)
+        elif os.path.isdir(state_dict_path):
             model.load_state_dict_dir(state_dict_path)
-        else:
+        elif os.path.isfile(state_dict_path):
             state_dict = torch.load(state_dict_path)
             model.load_state_dict_low_memory(state_dict)
+        else:
+            raise FileNotFoundError(f"Can not find model.safetensors or pytorch_model.bin in {pretrained_model_path}")
+
         return model
 
 
@@ -160,4 +177,4 @@ def load_state_dict_dir(self, state_dict_dir):
         self.chkpt_model.load_state_dict_dir(state_dict_dir)
 
     def load_state_dict_low_memory(self, state_dict):
-        self.chkpt_model.load_state_dict_low_memory(state_dict)
\ No newline at end of file
+        self.chkpt_model.load_state_dict_low_memory(state_dict)