Update Training PRO #4972

Merged · 4 commits · Dec 18, 2023
Changes from all commits
90 changes: 11 additions & 79 deletions extensions/Training_PRO/script.py
@@ -51,59 +51,9 @@
from modules.models import reload_model
from modules.utils import natural_keys



## just temporary to avoid warning

import inspect

from typing import Callable, Optional, Tuple, ContextManager



if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
def my_checkpoint(
function,
*args,
use_reentrant: Optional[bool] = None,
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = torch.utils.checkpoint.noop_context_fn,
determinism_check: str = torch.utils.checkpoint._DEFAULT_DETERMINISM_MODE,
debug: bool = False,
**kwargs
):

if use_reentrant is None:
#print ("reentran = NONE")
use_reentrant = True
# Hack to mix *args with **kwargs in a python 2.7-compliant way
preserve = kwargs.pop("preserve_rng_state", True)
if kwargs and use_reentrant:
raise ValueError(
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
)

if use_reentrant:
if context_fn is not torch.utils.checkpoint.noop_context_fn or debug is not False:
raise ValueError(
"Passing `context_fn` or `debug` is only supported when "
"use_reentrant=False."
)
return torch.utils.checkpoint.CheckpointFunction.apply(function, preserve, *args)
else:

print ("reentran = FALSE")
gen = torch.utils.checkpoint._checkpoint_without_reentrant_generator(
function, preserve, context_fn, determinism_check, debug, *args, **kwargs
)
# Runs pre-forward logic
next(gen)
ret = function(*args, **kwargs)
# Runs post-forward logic
try:
next(gen)
except StopIteration:
return ret

import warnings
warnings.filterwarnings(action = "ignore", message="torch.utils.checkpoint:")
warnings.filterwarnings(action = "ignore", message="`do_sample` is set to `False`")

params = {
"display_name": "Training PRO",
@@ -121,6 +71,7 @@ def my_checkpoint(
"save_epochs": 0,
"checkpoint_offset": 0,
"epoch_offset":0,
"safe_serialization": False,
}

MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
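
The new `safe_serialization` entry defaults to `False` and is forwarded to every `save_pretrained` call further down, so adapters keep being written in the pickle-based format unless the flag is flipped. A short sketch of the effect, assuming the standard PEFT `save_pretrained` signature; the base model name and output path are placeholders:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder base model; the extension uses whatever model is loaded in the UI.
base = AutoModelForCausalLM.from_pretrained("gpt2")
lora_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

# safe_serialization=False -> pickle-based adapter_model.bin (the default here)
lora_model.save_pretrained("loras/example", safe_serialization=False)

# safe_serialization=True -> adapter_model.safetensors
lora_model.save_pretrained("loras/example", safe_serialization=True)
```
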
@@ -150,7 +101,7 @@ def ui():
with gr.Row():
with gr.Column():
# YY.MM.DD
gr.Markdown("`Ver: 23.10.20` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
gr.Markdown("`Ver: 23.10.20 (REV2)` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")

with gr.Row():
with gr.Column(scale=5):
@@ -290,7 +241,7 @@ def ui():
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')

with gr.Column():
max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')

with gr.Row():
start_current_evaluation = gr.Button("Evaluate loaded model")
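
The evaluation hunk above swaps the hard-coded 8096 ceiling for `shared.settings['truncation_length_max']`, so the slider range follows the webui configuration. A tiny sketch of that pattern, with a stand-in `settings` dict (the 32768 value is illustrative, not from this PR):

```python
import gradio as gr

settings = {"truncation_length_max": 32768}  # stand-in for modules.shared.settings

with gr.Blocks() as demo:
    max_length = gr.Slider(
        label="max_length",
        minimum=0,
        maximum=settings["truncation_length_max"],
        value=0,
        step=1,
        info="The context for each evaluation. 0 = use the model's maximum context length.",
    )
```
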
@@ -712,7 +663,6 @@ def tokenize(prompt, append_eos_token=False, prepend_bos_token = False):
}

train_template.clear()


#reset stuff
print(f"*** LoRA: {lora_name} ***")
@@ -725,26 +675,8 @@ def tokenize(prompt, append_eos_token=False, prepend_bos_token = False):
non_serialized_params.update({"checkpoint_offset": 0})
non_serialized_params.update({"epoch_offset": 0})
train_log_graph.clear()

# === once fixed, this can be removed ==============================
if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
print("Testing Pytorch...")
old_checkpoint_signature = inspect.signature(torch.utils.checkpoint.checkpoint)

# Get the signature of your new checkpoint function
my_checkpoint_signature = inspect.signature(my_checkpoint)

# Check if the signatures match
if old_checkpoint_signature.parameters == my_checkpoint_signature.parameters:
print(F"{RED}Overriding Torch checkpoint function to avoid repeated 'use_reentrant not explicitly set' warnings{RESET}")
#print(" - Note: Transformers need to pass use_reentrant in llama.modeling_llama in def forward, layer_outputs = torch.utils.checkpoint.checkpoint")
#print(" Once they do, this function can be removed")
torch.utils.checkpoint.checkpoint = my_checkpoint


# END OF FPHAM SENTENCE SPLIT functions ===================

# == Prep the dataset, format, etc ==

# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
train_template["template_type"] = "raw_text"
logger.info("Loading text file...")
@@ -1025,7 +957,7 @@ def on_step_begin(self, args: transformers.TrainingArguments, state: transformer
force_save = True

if force_save:
lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/")
lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/", safe_serialization = non_serialized_params['safe_serialization'])
print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Saved: [{folder_save}]")
# Save log
with open(f"{lora_file_path}/{folder_save}/training_log.json", 'w', encoding='utf-8') as file:
@@ -1252,7 +1184,7 @@ def threaded_run():
log_train_dataset(trainer)
trainer.train()
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization'])
logger.info("LoRA training run is completed and saved.")
# Save log
with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
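
The comment in this hunk explains the design: the final save happens inside the worker thread, so the adapter is written even if the gradio thread or browser session dies mid-run. A minimal, self-contained sketch of that pattern, using stand-in callables rather than the real trainer and model:

```python
import threading

def run_and_save(train_fn, save_fn):
    train_fn()
    save_fn()  # saving inside the worker means it does not depend on the caller surviving

# Stand-ins for trainer.train() and lora_model.save_pretrained(...)
thread = threading.Thread(target=run_and_save,
                          args=(lambda: print("training"), lambda: print("saved")))
thread.start()
thread.join()
```
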
@@ -1353,7 +1285,7 @@ def threaded_run():

if not tracked.did_save:
logger.info("Training complete, saving...")
lora_model.save_pretrained(lora_file_path)
lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization'])

if WANT_INTERRUPT:
logger.info("Training interrupted.")