From f6904635d4610d4c8cb3fc60e4959516e34f511f Mon Sep 17 00:00:00 2001 From: jitto Date: Wed, 8 Nov 2023 19:18:37 -0600 Subject: [PATCH] Turn off safe_serialization so that save_function is called --- src/transformers_neuronx/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers_neuronx/module.py b/src/transformers_neuronx/module.py index faf0f1e5..80bafe86 100644 --- a/src/transformers_neuronx/module.py +++ b/src/transformers_neuronx/module.py @@ -25,7 +25,7 @@ warnings.filterwarnings("ignore", category=UserWarning, module='torch.nn.modules.lazy') def save_pretrained_split(model, save_directory): - model.save_pretrained(save_directory, save_function=save_split, max_shard_size='10000GB') + model.save_pretrained(save_directory, save_function=save_split, max_shard_size='10000GB', safe_serialization=False) _KEY_TO_FILENAME_JSON = 'key_to_filename.json' @@ -160,4 +160,4 @@ def load_state_dict_dir(self, state_dict_dir): self.chkpt_model.load_state_dict_dir(state_dict_dir) def load_state_dict_low_memory(self, state_dict): - self.chkpt_model.load_state_dict_low_memory(state_dict) \ No newline at end of file + self.chkpt_model.load_state_dict_low_memory(state_dict)