add use_flash_attention_2 to param for Model loader Transformers #4373

Merged (8 commits) on Nov 4, 2023
1 change: 1 addition & 0 deletions README.md
@@ -300,6 +300,7 @@ Optionally, you can use the following command-line flags:
| `--sdp-attention` | Use PyTorch 2.0's SDP attention. Same as above. |
| `--trust-remote-code` | Set `trust_remote_code=True` while loading the model. Necessary for some models. |
| `--use_fast` | Set `use_fast=True` while loading the tokenizer. |
| `--use_flash_attention_2` | Set use_flash_attention_2=True while loading the model. |

#### Accelerate 4-bit

2 changes: 1 addition & 1 deletion modules/loaders.py
@@ -9,7 +9,6 @@
'Transformers': [
'cpu_memory',
'gpu_memory',
'trust_remote_code',
'load_in_8bit',
'bf16',
'cpu',
@@ -21,6 +20,7 @@
'compute_dtype',
'trust_remote_code',
'use_fast',
'use_flash_attention_2',
'alpha_value',
'rope_freq_base',
'compress_pos_emb',
4 changes: 4 additions & 0 deletions modules/models.py
@@ -126,6 +126,10 @@ def huggingface_loader(model_name):
        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
        'use_safetensors': True if shared.args.force_safetensors else None
    }

    if shared.args.use_flash_attention_2:
        params['use_flash_attention_2'] = True

    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])

    if 'chatglm' in model_name.lower():
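
Note on the hunk above: the key is added to `params` only when the flag is set, so a default launch passes the same keyword arguments to the model loader as before this change. The following is a minimal, standalone sketch of that pattern, not the project's actual code. The helper names `build_parser` and `build_loader_params` are hypothetical and exist only for illustration; the flag names and help text are taken from the diff.

```python
import argparse


def build_parser():
    # Hypothetical helper: declares only the two flags relevant to this sketch.
    parser = argparse.ArgumentParser()
    parser.add_argument('--trust-remote-code', action='store_true',
                        help='Set trust_remote_code=True while loading the model.')
    parser.add_argument('--use_flash_attention_2', action='store_true',
                        help='Set use_flash_attention_2=True while loading the model.')
    return parser


def build_loader_params(args):
    # Hypothetical helper: mirrors the conditional in the hunk above.
    params = {'trust_remote_code': args.trust_remote_code}

    # Add the key only when requested, so the default keyword
    # arguments stay exactly as they were without the flag.
    if args.use_flash_attention_2:
        params['use_flash_attention_2'] = True

    return params


if __name__ == '__main__':
    parser = build_parser()
    print(build_loader_params(parser.parse_args([])))
    print(build_loader_params(parser.parse_args(['--use_flash_attention_2'])))
```

In this sketch the resulting dict would be unpacked into the loader call (for example `from_pretrained(path, **params)`), which is an assumption about usage rather than something shown in the hunk.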
1 change: 1 addition & 0 deletions modules/shared.py
@@ -93,6 +93,7 @@
parser.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.')
parser.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.')
parser.add_argument('--use_fast', action='store_true', help='Set use_fast=True while loading the tokenizer.')
parser.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.')

# Accelerate 4-bit
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).')
1 change: 1 addition & 0 deletions modules/ui.py
@@ -53,6 +53,7 @@ def list_model_elements():
'load_in_8bit',
'trust_remote_code',
'use_fast',
'use_flash_attention_2',
'load_in_4bit',
'compute_dtype',
'quant_type',
1 change: 1 addition & 0 deletions modules/ui_model_menu.py
@@ -124,6 +124,7 @@ def create_ui():
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code)
shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.')
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')