
Add support for quantized OPT models and refactor #295

Merged, 13 commits, Mar 14, 2023
README.md (5 changes: 3 additions & 2 deletions)

@@ -140,8 +140,9 @@ Optionally, you can use the following command-line flags:
 | `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
 | `--cpu` | Use the CPU to generate text.|
 | `--load-in-8bit` | Load the model with 8-bit precision.|
-| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.|
-| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
+| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
+| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with the specified precision. 2, 3, 4 and 8-bit are supported. Currently only works with LLaMA and OPT. |
+| `--gptq-model-type MODEL_TYPE` | Model type of the pre-quantized model. Currently only LLaMA and OPT are supported. |
 | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
 | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
 | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
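For context (an illustrative usage note, not part of this diff): with these flags, a pre-quantized model placed under `models/` would typically be launched with something like `python server.py --gptq-bits 4 --gptq-model-type opt`; if `--gptq-model-type` is omitted, the loader tries to infer the model family from the model folder's name (see modules/GPTQ_loader.py below).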
modules/quantized_LLaMA.py → modules/GPTQ_loader.py (38 changes: 25 additions & 13 deletions)

@@ -7,28 +7,40 @@
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
-from llama import load_quant
+import llama
+import opt
 
 
-# 4-bit LLaMA
-def load_quantized_LLaMA(model_name):
-    if shared.args.load_in_4bit:
-        bits = 4
+def load_quantized(model_name):
+    if not shared.args.gptq_model_type:
+        # Try to determine model type from model name
+        model_type = model_name.split('-')[0].lower()
+        if model_type not in ('llama', 'opt'):
+            print("Can't determine model type from model name. Please specify it manually using --gptq-model-type "
+                  "argument")
+            exit()
     else:
-        bits = shared.args.gptq_bits
+        model_type = shared.args.gptq_model_type.lower()
 
+    if model_type == 'llama':
+        load_quant = llama.load_quant
+    elif model_type == 'opt':
+        load_quant = opt.load_quant
+    else:
+        print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
+        exit()
+
     path_to_model = Path(f'models/{model_name}')
     pt_model = ''
     if path_to_model.name.lower().startswith('llama-7b'):
-        pt_model = f'llama-7b-{bits}bit.pt'
+        pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt'
     elif path_to_model.name.lower().startswith('llama-13b'):
-        pt_model = f'llama-13b-{bits}bit.pt'
+        pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt'
     elif path_to_model.name.lower().startswith('llama-30b'):
-        pt_model = f'llama-30b-{bits}bit.pt'
+        pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt'
     elif path_to_model.name.lower().startswith('llama-65b'):
-        pt_model = f'llama-65b-{bits}bit.pt'
+        pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt'
     else:
-        pt_model = f'{model_name}-{bits}bit.pt'
+        pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
 
     # Try to find the .pt both in models/ and in the subfolder
     pt_path = None
@@ -40,7 +52,7 @@ def load_quantized_LLaMA(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(str(path_to_model), str(pt_path), bits)
+    model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
 
     # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory:
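As an aside (not part of the PR), the two conventions the renamed loader relies on, inferring the GPTQ backend from the model name and locating the pre-quantized `.pt` checkpoint, can be summarized in a small standalone sketch; the example model names below are hypothetical.

```python
# Standalone sketch, not project code: it only restates the conventions that
# load_quantized() above follows.

def infer_gptq_model_type(model_name, explicit_type=None):
    """--gptq-model-type wins if given; otherwise the family is taken from the
    first dash-separated token of the model name (must be 'llama' or 'opt')."""
    model_type = (explicit_type or model_name.split('-')[0]).lower()
    if model_type not in ('llama', 'opt'):
        raise ValueError("Can't determine model type; pass --gptq-model-type explicitly")
    return model_type

def expected_pt_name(model_name, gptq_bits):
    """Name of the pre-quantized checkpoint the loader looks for, both directly
    under models/ and inside the model's own subfolder."""
    for prefix in ('llama-7b', 'llama-13b', 'llama-30b', 'llama-65b'):
        if model_name.lower().startswith(prefix):
            return f'{prefix}-{gptq_bits}bit.pt'
    return f'{model_name}-{gptq_bits}bit.pt'

print(infer_gptq_model_type('opt-13b'))        # opt
print(infer_gptq_model_type('llama-30b'))      # llama
print(expected_pt_name('opt-13b', 4))          # opt-13b-4bit.pt
print(expected_pt_name('llama-13b-hf', 4))     # llama-13b-4bit.pt
```

So a hypothetical OPT model stored at `models/opt-13b` and launched with `--gptq-bits 4` would be expected to provide an `opt-13b-4bit.pt` file.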
modules/models.py (12 changes: 6 additions & 6 deletions)

@@ -1,6 +1,5 @@
 import json
 import os
-import sys
 import time
 import zipfile
 from pathlib import Path

@@ -35,14 +34,15 @@
 ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
 dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
 
+
 def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
     shared.is_RWKV = model_name.lower().startswith('rwkv-')
 
     # Default settings
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
Review thread on this line:

Owner: I'm reluctant to remove --load-in-4bit because that will certainly cause confusion, but I guess we can do it and move on.

Contributor (author): It's up to you, but I think it's bad practice to keep multiple arguments that do the same thing.

Contributor (author): Also, this --load-in-4bit argument was confusing because it works differently from --load-in-8bit.

Owner: You have convinced me, let's ditch --load-in-4bit.
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
@@ -87,11 +87,11 @@ def load_model(model_name):
 
         return model, tokenizer
 
-    # 4-bit LLaMA
-    elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit:
-        from modules.quantized_LLaMA import load_quantized_LLaMA
+    # Quantized model
+    elif shared.args.gptq_bits > 0:
+        from modules.GPTQ_loader import load_quantized
 
-        model = load_quantized_LLaMA(model_name)
+        model = load_quantized(model_name)
 
     # Custom
     else:
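To make the effect of the changed conditions concrete, here is an illustrative-only snippet (not the project's code); it copies the two expressions from this diff and ignores the other branches, such as DeepSpeed, FlexGen and RWKV:

```python
# Illustrative only: args is a stand-in for shared.args carrying the flags
# used in the two conditions changed by this diff.
from types import SimpleNamespace

args = SimpleNamespace(cpu=False, load_in_8bit=False, gptq_bits=4,  # e.g. --gptq-bits 4
                       auto_devices=False, disk=False, gpu_memory=None,
                       cpu_memory=None, deepspeed=False, flexgen=False)
is_RWKV = False

# New "default settings" guard: any non-zero --gptq-bits disables the shortcut.
default_settings = not any([args.cpu, args.load_in_8bit, args.gptq_bits,
                            args.auto_devices, args.disk,
                            args.gpu_memory is not None, args.cpu_memory is not None,
                            args.deepspeed, args.flexgen, is_RWKV])
print(default_settings)    # False

# New quantized branch: keyed solely on --gptq-bits; both LLaMA and OPT go
# through modules.GPTQ_loader.load_quantized().
print(args.gptq_bits > 0)  # True
```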
modules/shared.py (10 changes: 8 additions & 2 deletions)

@@ -69,8 +69,9 @@ def str2bool(v):
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
 parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
-parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.')
-parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.')
+parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.')
+parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with the specified precision. 2, 3, 4 and 8-bit are supported. Currently only works with LLaMA and OPT.')
+parser.add_argument('--gptq-model-type', type=str, help='Model type of the pre-quantized model. Currently only LLaMA and OPT are supported.')
 parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
 parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
 parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
@@ -95,3 +96,8 @@ def str2bool(v):
 parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 args = parser.parse_args()
+
+# Provisional, this will be deleted later
+if args.load_in_4bit:
+    print("Warning: --load-in-4bit is deprecated and will be removed. Use --gptq-bits 4 instead.\n")
+    args.gptq_bits = 4
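For illustration, a stripped-down stand-in for the parser in modules/shared.py (not the real module) shows how this shim folds the deprecated flag into the new one:

```python
# Stripped-down stand-in for modules/shared.py, just to show the shim's effect.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.')
parser.add_argument('--gptq-bits', type=int, default=0)
parser.add_argument('--gptq-model-type', type=str)

args = parser.parse_args(['--load-in-4bit'])  # simulate the deprecated invocation
if args.load_in_4bit:
    # The real code also prints a deprecation warning here.
    args.gptq_bits = 4

print(args.gptq_bits)        # 4
print(args.gptq_model_type)  # None: GPTQ_loader will try to infer it from the model name
```

Downstream code therefore only ever needs to check `args.gptq_bits`, which is exactly what the new condition in modules/models.py does.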