diff --git a/src/transformers_neuronx/module.py b/src/transformers_neuronx/module.py index faf0f1e5..80bafe86 100644 --- a/src/transformers_neuronx/module.py +++ b/src/transformers_neuronx/module.py @@ -25,7 +25,7 @@ warnings.filterwarnings("ignore", category=UserWarning, module='torch.nn.modules.lazy') def save_pretrained_split(model, save_directory): - model.save_pretrained(save_directory, save_function=save_split, max_shard_size='10000GB') + model.save_pretrained(save_directory, save_function=save_split, max_shard_size='10000GB', safe_serialization=False) _KEY_TO_FILENAME_JSON = 'key_to_filename.json' @@ -160,4 +160,4 @@ def load_state_dict_dir(self, state_dict_dir): self.chkpt_model.load_state_dict_dir(state_dict_dir) def load_state_dict_low_memory(self, state_dict): - self.chkpt_model.load_state_dict_low_memory(state_dict) \ No newline at end of file + self.chkpt_model.load_state_dict_low_memory(state_dict)