Skip to content

Commit

Permalink
transformers: Add eager attention option to make Gemma-2 work properly (
Browse files Browse the repository at this point in the history
  • Loading branch information
GralchemOz authored and PoetOnTheRun committed Oct 22, 2024
1 parent 2b60365 commit 33cc62d
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions modules/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
'trust_remote_code',
'no_use_fast',
'use_flash_attention_2',
'use_eager_attention',
'alpha_value',
'compress_pos_emb',
'disable_exllama',
Expand Down
3 changes: 3 additions & 0 deletions modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def huggingface_loader(model_name):
if shared.args.force_safetensors:
params['force_safetensors'] = True

if shared.args.use_eager_attention:
params['attn_implementation'] = 'eager'

config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)

if 'chatglm' in model_name.lower():
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
group.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.')
group.add_argument('--no_use_fast', action='store_true', help='Set use_fast=False while loading the tokenizer (it\'s True by default). Use this if you have any problems related to use_fast.')
group.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.')
group.add_argument('--use_eager_attention', action='store_true', help='Set attn_implementation= eager while loading the model.')

# bitsandbytes 4-bit
group = parser.add_argument_group('bitsandbytes 4-bit')
Expand Down
1 change: 1 addition & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def list_model_elements():
'trust_remote_code',
'no_use_fast',
'use_flash_attention_2',
'use_eager_attention',
'load_in_4bit',
'compute_dtype',
'quant_type',
Expand Down
1 change: 1 addition & 0 deletions modules/ui_model_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create_ui():
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
shared.gradio['use_eager_attention'] = gr.Checkbox(label="use_eager_attention", value=shared.args.use_eager_attention, info='Set attn_implementation= eager while loading the model.')
shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This increases performance on RTX cards.')
Expand Down

0 comments on commit 33cc62d

Please sign in to comment.