From 59da429cbd7ece3cf1d0f03ce5493baf5653e4a2 Mon Sep 17 00:00:00 2001 From: "FartyPants (FP HAM)" Date: Sun, 17 Dec 2023 21:54:06 -0500 Subject: [PATCH] Update Training PRO (#4972) - rolling back safetensors to bin, until it is fixed correctly - removing the ugly checkpoint detour --- extensions/Training_PRO/script.py | 90 ++++--------------------------- 1 file changed, 11 insertions(+), 79 deletions(-) diff --git a/extensions/Training_PRO/script.py b/extensions/Training_PRO/script.py index 5afa627e6b..8f29646232 100644 --- a/extensions/Training_PRO/script.py +++ b/extensions/Training_PRO/script.py @@ -51,59 +51,9 @@ from modules.models import reload_model from modules.utils import natural_keys - - -## just temporary to avoid warning - -import inspect - -from typing import Callable, Optional, Tuple, ContextManager - - - -if hasattr(torch.utils.checkpoint, 'noop_context_fn'): - def my_checkpoint( - function, - *args, - use_reentrant: Optional[bool] = None, - context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = torch.utils.checkpoint.noop_context_fn, - determinism_check: str = torch.utils.checkpoint._DEFAULT_DETERMINISM_MODE, - debug: bool = False, - **kwargs - ): - - if use_reentrant is None: - #print ("reentran = NONE") - use_reentrant = True - # Hack to mix *args with **kwargs in a python 2.7-compliant way - preserve = kwargs.pop("preserve_rng_state", True) - if kwargs and use_reentrant: - raise ValueError( - "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) - ) - - if use_reentrant: - if context_fn is not torch.utils.checkpoint.noop_context_fn or debug is not False: - raise ValueError( - "Passing `context_fn` or `debug` is only supported when " - "use_reentrant=False." 
- ) - return torch.utils.checkpoint.CheckpointFunction.apply(function, preserve, *args) - else: - - print ("reentran = FALSE") - gen = torch.utils.checkpoint._checkpoint_without_reentrant_generator( - function, preserve, context_fn, determinism_check, debug, *args, **kwargs - ) - # Runs pre-forward logic - next(gen) - ret = function(*args, **kwargs) - # Runs post-forward logic - try: - next(gen) - except StopIteration: - return ret - +import warnings +warnings.filterwarnings(action = "ignore", message="torch.utils.checkpoint:") +warnings.filterwarnings(action = "ignore", message="`do_sample` is set to `False`") params = { "display_name": "Training PRO", @@ -121,6 +71,7 @@ def my_checkpoint( "save_epochs": 0, "checkpoint_offset": 0, "epoch_offset":0, + "safe_serialization": False, } MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()} @@ -150,7 +101,7 @@ def ui(): with gr.Row(): with gr.Column(): # YY.MM.DD - gr.Markdown("`Ver: 23.10.20` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)") + gr.Markdown("`Ver: 23.10.20 (REV2)` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)") with gr.Row(): with gr.Column(scale=5): @@ -290,7 +241,7 @@ def ui(): stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') with gr.Column(): - max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.') + max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=1, info='The context for each evaluation. 
If set to 0, the maximum context length for the model will be used.') with gr.Row(): start_current_evaluation = gr.Button("Evaluate loaded model") @@ -712,7 +663,6 @@ def tokenize(prompt, append_eos_token=False, prepend_bos_token = False): } train_template.clear() - #reset stuff print(f"*** LoRA: {lora_name} ***") @@ -725,26 +675,8 @@ def tokenize(prompt, append_eos_token=False, prepend_bos_token = False): non_serialized_params.update({"checkpoint_offset": 0}) non_serialized_params.update({"epoch_offset": 0}) train_log_graph.clear() - - # === once fixed, this can be removed ============================== - if hasattr(torch.utils.checkpoint, 'noop_context_fn'): - print("Testing Pytorch...") - old_checkpoint_signature = inspect.signature(torch.utils.checkpoint.checkpoint) - - # Get the signature of your new checkpoint function - my_checkpoint_signature = inspect.signature(my_checkpoint) - - # Check if the signatures match - if old_checkpoint_signature.parameters == my_checkpoint_signature.parameters: - print(F"{RED}Overriding Torch checkpoint function to avoid repeated 'use_reentrant not explicitly set' warnings{RESET}") - #print(" - Note: Transformers need to pass use_reentrant in llama.modeling_llama in def forward, layer_outputs = torch.utils.checkpoint.checkpoint") - #print(" Once they do, this function can be removed") - torch.utils.checkpoint.checkpoint = my_checkpoint - - - # END OF FPHAM SENTENCE SPLIT functions =================== - - # == Prep the dataset, format, etc == + + # == Prep the dataset, format, etc == if raw_text_file not in ['None', '']: train_template["template_type"] = "raw_text" logger.info("Loading text file...") @@ -1025,7 +957,7 @@ def on_step_begin(self, args: transformers.TrainingArguments, state: transformer force_save = True if force_save: - lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/") + lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/", safe_serialization = non_serialized_params['safe_serialization']) 
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Saved: [{folder_save}]") # Save log with open(f"{lora_file_path}/{folder_save}/training_log.json", 'w', encoding='utf-8') as file: @@ -1252,7 +1184,7 @@ def threaded_run(): log_train_dataset(trainer) trainer.train() # Note: save in the thread in case the gradio thread breaks (eg browser closed) - lora_model.save_pretrained(lora_file_path) + lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization']) logger.info("LoRA training run is completed and saved.") # Save log with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file: @@ -1353,7 +1285,7 @@ def threaded_run(): if not tracked.did_save: logger.info("Training complete, saving...") - lora_model.save_pretrained(lora_file_path) + lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization']) if WANT_INTERRUPT: logger.info("Training interrupted.")