Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test thread pool #354

Merged
merged 9 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions gptqmodel/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import shutil
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from typing import List, Optional

Expand Down Expand Up @@ -263,6 +264,26 @@ def select_quant_linear_with_pack(bits: int,
)
return QuantLinear

def pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar):
# Limit pack() thread usage to avoid auto-parallizataion regression
with tctl.threadpool_limits(limits=1):
pbar.set_description(f"Packing {name}")
quantizers[name], scale, zero, g_idx = quantizers[name]
layer_device = qlayers[name].device
qlayers[name].to(CPU)
layers[name], scale, zero, g_idx = (
layers[name].to(CPU),
scale.to(CPU),
zero.to(CPU),
g_idx.to(CPU),
)
if QuantLinear is MarlinQuantLinear:
qlayers[name].pack(layers[name], scale)
else:
qlayers[name].pack(layers[name], scale, zero, g_idx)
qlayers[name].to(layer_device)
pbar.update()

def pack_model(
model,
quantizers,
Expand Down Expand Up @@ -290,6 +311,7 @@ def pack_model(
model.to(CPU)

logger.info("Packing model...")

layers = find_layers(model)
layers = {n: layers[n] for n in quantizers}
make_quant(
Expand All @@ -304,31 +326,16 @@ def pack_model(
dynamic=dynamic,
)
qlayers = find_layers(model, [QuantLinear])
names = list(qlayers.keys())
with ThreadPoolExecutor(max_workers=2) as executor:
with tqdm(total=len(names), leave=True) as pbar:
def wrapper(name):
pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar)

# Limit pack() thread usage to avoid auto-parallizataion regression
with tctl.threadpool_limits(limits=1):
pbar = tqdm(qlayers.keys(), leave=True)
for name in pbar:
pbar.set_description(f"Packing {name}")

quantizers[name], scale, zero, g_idx = quantizers[name]
# so far can only pack layer on CPU
layer_device = qlayers[name].device
qlayers[name].to(CPU)
layers[name], scale, zero, g_idx = (
layers[name].to(CPU),
scale.to(CPU),
zero.to(CPU),
g_idx.to(CPU),
)
if QuantLinear is MarlinQuantLinear:
qlayers[name].pack(layers[name], scale)
else:
qlayers[name].pack(layers[name], scale, zero, g_idx)
qlayers[name].to(layer_device)
for _ in executor.map(wrapper, names):
pass

logger.info("Model packed.")

return QuantLinear

def verify_model_hash(file_path: str, verify_hash: str):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_quant_trust_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def test_diff_batch(self):
del model
py_files = [f for f in os.listdir(tmp_dir) if f.endswith('.py')]
expected_files = ["modeling_minicpm.py", "configuration_minicpm.py"]
self.assertEqual(py_files, expected_files)
for file in expected_files:
self.assertIn(file, py_files, f"File {file} is missing in the actual files list")



Loading