Commit

Merge remote-tracking branch 'refs/remotes/origin/dev' into dev
oobabooga committed Nov 2, 2023
2 parents 77abd9b + a56ef2a commit 42f8163
Showing 4 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -26,6 +26,7 @@
 .DS_Store
 .eslintrc.js
 .idea
+.env
 .venv
 venv
 .vscode
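A note on the new .env entry: .env files usually hold per-machine secrets such as API keys, which is why they are kept out of version control. The snippet below is only a hypothetical illustration of how such a file might be read at startup; it is not code from this repository.

import os
from pathlib import Path

# Hypothetical example: read KEY=VALUE pairs from a local .env file into
# the process environment. Illustrative only; the repository may load its
# settings differently.
env_file = Path(".env")
if env_file.is_file():
    for line in env_file.read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())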
6 changes: 3 additions & 3 deletions models/config.yaml
@@ -45,9 +45,6 @@
 .*starchat-beta:
   instruction_template: 'Starchat-Beta'
   custom_stopping_strings: '"<|end|>"'
-.*(openorca-platypus2):
-  instruction_template: 'OpenOrca-Platypus2'
-  custom_stopping_strings: '"### Instruction:", "### Response:"'
 (?!.*v0)(?!.*1.1)(?!.*1_1)(?!.*stable)(?!.*chinese).*vicuna:
   instruction_template: 'Vicuna-v0'
 .*vicuna.*v0:
@@ -152,6 +149,9 @@
   instruction_template: 'Orca Mini'
 .*(platypus|gplatty|superplatty):
   instruction_template: 'Alpaca'
+.*(openorca-platypus2):
+  instruction_template: 'OpenOrca-Platypus2'
+  custom_stopping_strings: '"### Instruction:", "### Response:"'
 .*longchat:
   instruction_template: 'Vicuna-v1.1'
 .*vicuna-33b:
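The openorca-platypus2 entry moves below the generic platypus rule. The keys in models/config.yaml are regular expressions matched against the model name, and matching entries appear to be applied in order, so a later, more specific pattern can override an earlier, generic one. The sketch below shows that resolution logic under that assumption; it is not the project's actual loader.

import re

# Ordered regex -> settings mapping, mimicking the structure of
# models/config.yaml (illustrative subset, not the real file contents).
CONFIG = {
    r".*(platypus|gplatty|superplatty)": {"instruction_template": "Alpaca"},
    r".*(openorca-platypus2)": {
        "instruction_template": "OpenOrca-Platypus2",
        "custom_stopping_strings": '"### Instruction:", "### Response:"',
    },
}

def resolve_settings(model_name: str) -> dict:
    # Apply every matching pattern in order; later matches overwrite
    # earlier ones, which is why entry order matters.
    settings = {}
    for pattern, values in CONFIG.items():
        if re.match(pattern.lower(), model_name.lower()):
            settings.update(values)
    return settings

print(resolve_settings("OpenOrca-Platypus2-13B"))
# The more specific entry wins: instruction_template is 'OpenOrca-Platypus2'.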
2 changes: 1 addition & 1 deletion modules/GPTQ_loader.py
@@ -62,7 +62,7 @@ def noop(*args, **kwargs):
         from safetensors.torch import load_file as safe_load
         model.load_state_dict(safe_load(checkpoint), strict=False)
     else:
-        model.load_state_dict(torch.load(checkpoint), strict=False)
+        model.load_state_dict(torch.load(checkpoint, weights_only=True), strict=False)
 
     model.seqlen = 2048
     return model
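Both torch.load call sites in this commit (here and in modules/training.py below) now pass weights_only=True. With that flag, PyTorch restricts unpickling to tensors and primitive container types instead of arbitrary Python objects, which mitigates code execution from a maliciously crafted checkpoint. A minimal, self-contained sketch of the safer load path, using a toy module and a hypothetical file name rather than this repository's models:

import torch
import torch.nn as nn

# Hypothetical round trip with a toy module, just to show the call pattern.
model = nn.Linear(4, 4)
torch.save(model.state_dict(), "checkpoint.pt")

# weights_only=True limits the unpickler to tensors and basic container
# types, so a tampered checkpoint cannot run arbitrary code when loaded.
state_dict = torch.load("checkpoint.pt", weights_only=True)
model.load_state_dict(state_dict, strict=False)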
2 changes: 1 addition & 1 deletion modules/training.py
@@ -544,7 +544,7 @@ def generate_and_tokenize_prompt(data_point):
         lora_model = get_peft_model(shared.model, config)
         if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
             logger.info("Loading existing LoRA data...")
-            state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
+            state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
             set_peft_model_state_dict(lora_model, state_dict_peft)
     except:
         yield traceback.format_exc().replace('\n', '\n\n')
