From 231d615a58b70552778abb56c8a4e50fc61ca56d Mon Sep 17 00:00:00 2001 From: PZS-ModelCloud Date: Tue, 13 Aug 2024 04:17:00 +0000 Subject: [PATCH 1/7] test thread pool --- gptqmodel/utils/model.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index ce63373d..3e1cc151 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -263,6 +263,22 @@ def select_quant_linear_with_pack(bits: int, ) return QuantLinear +def pack_layer(name, qlayers, quantizers, layers, QuantLinear): + quantizers[name], scale, zero, g_idx = quantizers[name] + layer_device = qlayers[name].device + qlayers[name].to(CPU) + layers[name], scale, zero, g_idx = ( + layers[name].to(CPU), + scale.to(CPU), + zero.to(CPU), + g_idx.to(CPU), + ) + if QuantLinear is MarlinQuantLinear: + qlayers[name].pack(layers[name], scale) + else: + qlayers[name].pack(layers[name], scale, zero, g_idx) + qlayers[name].to(layer_device) + def pack_model( model, quantizers, @@ -275,6 +291,9 @@ def pack_model( force_layer_back_to_cpu: bool = False, dynamic=None, ): + import time + from concurrent.futures import ThreadPoolExecutor + QuantLinear = select_quant_linear_with_pack( bits=bits, dynamic=dynamic, @@ -290,6 +309,7 @@ def pack_model( model.to(CPU) logger.info("Packing model...") + start = time.time() layers = find_layers(model) layers = {n: layers[n] for n in quantizers} make_quant( @@ -305,6 +325,15 @@ def pack_model( ) qlayers = find_layers(model, [QuantLinear]) + # with ThreadPoolExecutor(max_workers=4) as executor: + # executor.map( + # pack_layer, + # qlayers.keys(), + # [qlayers] * len(qlayers), + # [quantizers] * len(qlayers), + # [layers] * len(qlayers), + # [QuantLinear] * len(qlayers) + # ) # Limit pack() thread usage to avoid auto-parallizataion regression with tctl.threadpool_limits(limits=1): pbar = tqdm(qlayers.keys(), leave=True) @@ -328,7 +357,7 @@ def pack_model( qlayers[name].to(layer_device) logger.info("Model packed.") - + print(f"Time for pack: {time.time() - start}") return QuantLinear def verify_model_hash(file_path: str, verify_hash: str): From b36ca6c5d0d68e0b45457e0ad4a3801a356e7f55 Mon Sep 17 00:00:00 2001 From: PZS-ModelCloud Date: Tue, 13 Aug 2024 04:22:18 +0000 Subject: [PATCH 2/7] update code --- gptqmodel/utils/model.py | 58 ++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 3e1cc151..56dbfb6c 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -325,36 +325,36 @@ def pack_model( ) qlayers = find_layers(model, [QuantLinear]) - # with ThreadPoolExecutor(max_workers=4) as executor: - # executor.map( - # pack_layer, - # qlayers.keys(), - # [qlayers] * len(qlayers), - # [quantizers] * len(qlayers), - # [layers] * len(qlayers), - # [QuantLinear] * len(qlayers) - # ) + with ThreadPoolExecutor(max_workers=4) as executor: + executor.map( + pack_layer, + qlayers.keys(), + [qlayers] * len(qlayers), + [quantizers] * len(qlayers), + [layers] * len(qlayers), + [QuantLinear] * len(qlayers) + ) # Limit pack() thread usage to avoid auto-parallizataion regression - with tctl.threadpool_limits(limits=1): - pbar = tqdm(qlayers.keys(), leave=True) - for name in pbar: - pbar.set_description(f"Packing {name}") - - quantizers[name], scale, zero, g_idx = quantizers[name] - # so far can only pack layer on CPU - layer_device = qlayers[name].device - qlayers[name].to(CPU) - layers[name], scale, zero, g_idx = ( - layers[name].to(CPU), - scale.to(CPU), - zero.to(CPU), - g_idx.to(CPU), - ) - if QuantLinear is MarlinQuantLinear: - qlayers[name].pack(layers[name], scale) - else: - qlayers[name].pack(layers[name], scale, zero, g_idx) - qlayers[name].to(layer_device) + # with tctl.threadpool_limits(limits=1): + # pbar = tqdm(qlayers.keys(), leave=True) + # for name in pbar: + # pbar.set_description(f"Packing {name}") + # + # quantizers[name], scale, zero, g_idx = quantizers[name] + # # so far can only pack layer on CPU + # layer_device = qlayers[name].device + # qlayers[name].to(CPU) + # layers[name], scale, zero, g_idx = ( + # layers[name].to(CPU), + # scale.to(CPU), + # zero.to(CPU), + # g_idx.to(CPU), + # ) + # if QuantLinear is MarlinQuantLinear: + # qlayers[name].pack(layers[name], scale) + # else: + # qlayers[name].pack(layers[name], scale, zero, g_idx) + # qlayers[name].to(layer_device) logger.info("Model packed.") print(f"Time for pack: {time.time() - start}") From 40943097e69463e4a73fdc3b459506377710dc6a Mon Sep 17 00:00:00 2001 From: PZS-ModelCloud Date: Tue, 13 Aug 2024 05:12:17 +0000 Subject: [PATCH 3/7] update code --- gptqmodel/utils/model.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 56dbfb6c..68a5fca5 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -264,20 +264,21 @@ def select_quant_linear_with_pack(bits: int, return QuantLinear def pack_layer(name, qlayers, quantizers, layers, QuantLinear): - quantizers[name], scale, zero, g_idx = quantizers[name] - layer_device = qlayers[name].device - qlayers[name].to(CPU) - layers[name], scale, zero, g_idx = ( - layers[name].to(CPU), - scale.to(CPU), - zero.to(CPU), - g_idx.to(CPU), - ) - if QuantLinear is MarlinQuantLinear: - qlayers[name].pack(layers[name], scale) - else: - qlayers[name].pack(layers[name], scale, zero, g_idx) - qlayers[name].to(layer_device) + with tctl.threadpool_limits(limits=1): + quantizers[name], scale, zero, g_idx = quantizers[name] + layer_device = qlayers[name].device + qlayers[name].to(CPU) + layers[name], scale, zero, g_idx = ( + layers[name].to(CPU), + scale.to(CPU), + zero.to(CPU), + g_idx.to(CPU), + ) + if QuantLinear is MarlinQuantLinear: + qlayers[name].pack(layers[name], scale) + else: + qlayers[name].pack(layers[name], scale, zero, g_idx) + qlayers[name].to(layer_device) def pack_model( model, From 1f270c13b9414cf337d16f9818e7496fa2001b73 Mon Sep 17 00:00:00 2001 From: PZS-ModelCloud Date: Tue, 13 Aug 2024 05:19:20 +0000 Subject: [PATCH 4/7] mod start_time line --- gptqmodel/utils/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 68a5fca5..5de2ff55 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -310,7 +310,7 @@ def pack_model( model.to(CPU) logger.info("Packing model...") - start = time.time() + layers = find_layers(model) layers = {n: layers[n] for n in quantizers} make_quant( @@ -325,7 +325,7 @@ def pack_model( dynamic=dynamic, ) qlayers = find_layers(model, [QuantLinear]) - + start = time.time() with ThreadPoolExecutor(max_workers=4) as executor: executor.map( pack_layer, From 44d9b3ddc996fcd3de90e06e8d5fd1c8769751c3 Mon Sep 17 00:00:00 2001 From: PZS-ModelCloud Date: Tue, 13 Aug 2024 09:05:02 +0000 Subject: [PATCH 5/7] add tqdm --- gptqmodel/utils/model.py | 49 +++++++++++----------------------------- 1 file changed, 13 insertions(+), 36 deletions(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 5de2ff55..fcab84c8 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -17,6 +17,7 @@ from tqdm import tqdm from transformers import AutoConfig, PretrainedConfig from transformers.utils.hub import cached_file +from concurrent.futures import ThreadPoolExecutor from ..models._const import CPU, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS from ..nn_modules.qlinear import BaseQuantLinear @@ -263,8 +264,10 @@ def select_quant_linear_with_pack(bits: int, ) return QuantLinear -def pack_layer(name, qlayers, quantizers, layers, QuantLinear): +def pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar): + # Limit pack() thread usage to avoid auto-parallizataion regression with tctl.threadpool_limits(limits=1): + pbar.set_description(f"Packing {name}") quantizers[name], scale, zero, g_idx = quantizers[name] layer_device = qlayers[name].device qlayers[name].to(CPU) @@ -279,6 +282,7 @@ def pack_layer(name, qlayers, quantizers, layers, QuantLinear): else: qlayers[name].pack(layers[name], scale, zero, g_idx) qlayers[name].to(layer_device) + pbar.update() def pack_model( model, @@ -292,9 +296,6 @@ def pack_model( force_layer_back_to_cpu: bool = False, dynamic=None, ): - import time - from concurrent.futures import ThreadPoolExecutor - QuantLinear = select_quant_linear_with_pack( bits=bits, dynamic=dynamic, @@ -325,40 +326,16 @@ def pack_model( dynamic=dynamic, ) qlayers = find_layers(model, [QuantLinear]) - start = time.time() - with ThreadPoolExecutor(max_workers=4) as executor: - executor.map( - pack_layer, - qlayers.keys(), - [qlayers] * len(qlayers), - [quantizers] * len(qlayers), - [layers] * len(qlayers), - [QuantLinear] * len(qlayers) - ) - # Limit pack() thread usage to avoid auto-parallizataion regression - # with tctl.threadpool_limits(limits=1): - # pbar = tqdm(qlayers.keys(), leave=True) - # for name in pbar: - # pbar.set_description(f"Packing {name}") - # - # quantizers[name], scale, zero, g_idx = quantizers[name] - # # so far can only pack layer on CPU - # layer_device = qlayers[name].device - # qlayers[name].to(CPU) - # layers[name], scale, zero, g_idx = ( - # layers[name].to(CPU), - # scale.to(CPU), - # zero.to(CPU), - # g_idx.to(CPU), - # ) - # if QuantLinear is MarlinQuantLinear: - # qlayers[name].pack(layers[name], scale) - # else: - # qlayers[name].pack(layers[name], scale, zero, g_idx) - # qlayers[name].to(layer_device) + names = list(qlayers.keys()) + with ThreadPoolExecutor(max_workers=2) as executor: + with tqdm(total=len(names), leave=True) as pbar: + def wrapper(name): + pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar) + + for _ in executor.map(wrapper, names): + pass logger.info("Model packed.") - print(f"Time for pack: {time.time() - start}") return QuantLinear def verify_model_hash(file_path: str, verify_hash: str): From 2e546eafb9266ed7207630faeb1fdcbd946a3d84 Mon Sep 17 00:00:00 2001 From: PZS-ModelCloud Date: Tue, 13 Aug 2024 09:05:25 +0000 Subject: [PATCH 6/7] mod test code --- tests/test_quant_trust_remote.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_quant_trust_remote.py b/tests/test_quant_trust_remote.py index 9cf8fea5..06bb12c2 100644 --- a/tests/test_quant_trust_remote.py +++ b/tests/test_quant_trust_remote.py @@ -48,6 +48,8 @@ def test_diff_batch(self): del model py_files = [f for f in os.listdir(tmp_dir) if f.endswith('.py')] expected_files = ["modeling_minicpm.py", "configuration_minicpm.py"] - self.assertEqual(py_files, expected_files) + for file in expected_files: + self.assertIn(file, py_files, f"File {file} is missing in the actual files list") + From 3386fc855d51232fcee5435b4e23095b769db43b Mon Sep 17 00:00:00 2001 From: PZS-ModelCloud Date: Tue, 13 Aug 2024 09:05:48 +0000 Subject: [PATCH 7/7] format code --- gptqmodel/utils/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index fcab84c8..a8b83cdd 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -5,6 +5,7 @@ import os import re import shutil +from concurrent.futures import ThreadPoolExecutor from logging import getLogger from typing import List, Optional @@ -17,7 +18,6 @@ from tqdm import tqdm from transformers import AutoConfig, PretrainedConfig from transformers.utils.hub import cached_file -from concurrent.futures import ThreadPoolExecutor from ..models._const import CPU, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS from ..nn_modules.qlinear import BaseQuantLinear