Support automatic detection of amp and device for AutoRound [2.x] #1649

Merged · 2 commits · Mar 6, 2024
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/ut/env_setup.sh
@@ -99,7 +99,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
 fi

 if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install git+https://github.com/intel/auto-round.git@b65830f3f6cb32d92a5c8ba5f80ace12d517357b
+    pip install git+https://github.com/intel/auto-round.git@6815f8b66be456ecbef2d0beb33dbc4efeefdc04
 fi

 # test deps

8 changes: 0 additions & 8 deletions neural_compressor/adaptor/pytorch.py
@@ -4918,8 +4918,6 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
         # auto round recipes
         enable_full_range = self.recipes["autoround_args"].get("enable_full_range", False)
         bs = self.recipes["autoround_args"].get("bs", 8)
-        amp = self.recipes["autoround_args"].get("amp", True)
-        device = self.recipes["autoround_args"].get("device", "cpu")
         lr_scheduler = self.recipes["autoround_args"].get("lr_scheduler", None)
         dataset_name = self.recipes["autoround_args"].get("dataset_name", "NeelNanda/pile-10k")
         dataset_split = self.recipes["autoround_args"].get("dataset_split", "train")
@@ -4939,8 +4937,6 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
         dynamic_max_gap = self.recipes["autoround_args"].get("dynamic_max_gap", -1)
         data_type = self.recipes["autoround_args"].get("data_type", "int")  ##only support data_type
         scale_dtype = self.recipes["autoround_args"].get("scale_dtype", "fp16")
-        # autoround export
-        export_args = self.recipes["autoround_args"].get("export_args", {"format": None})

         model, autoround_config = autoround_quantize(
             model=model,
@@ -4951,8 +4947,6 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
             weight_config=weight_config,
             enable_full_range=enable_full_range,
             bs=bs,
-            amp=amp,
-            device=device,
             lr_scheduler=lr_scheduler,
             dataloader=dataloader,
             dataset_name=dataset_name,
@@ -4973,8 +4967,6 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
             dynamic_max_gap=dynamic_max_gap,
             data_type=data_type,
             scale_dtype=scale_dtype,
-            # export arguments
-            export_args=export_args,
         )
         return model, autoround_config

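With amp, device, and export_args dropped from the recipe parsing above, an AutoRound recipe in 2.x carries only tuning knobs. A minimal sketch of the resulting user-facing config, based on the updated test further down (fp32_model and dataloader are placeholders):

from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {
            "weight": {"dtype": "int", "bits": 4, "group_size": 32, "scheme": "sym", "algorithm": "AUTOROUND"},
        },
    },
    recipes={
        "autoround_args": {
            "n_samples": 20,
            "seq_len": 10,
            "iters": 10,
            "scale_dtype": "fp32",
            # no "amp" or "device" keys: both are now detected automatically
        },
    },
)
q_model = quantization.fit(fp32_model, conf, calib_dataloader=dataloader)  # placeholders for a real model/dataloader
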
16 changes: 3 additions & 13 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -682,7 +682,7 @@ def autoround_quantize(
     enable_full_range: bool = False,  ##for symmetric, TODO support later
     bs: int = 8,
     amp: bool = True,
-    device="cuda:0",
+    device=None,
     lr_scheduler=None,
     dataloader=None,  ## to support later
     dataset_name: str = "NeelNanda/pile-10k",
@@ -703,7 +703,6 @@ def autoround_quantize(
     dynamic_max_gap: int = -1,
     data_type: str = "int",  ##only support data_type
     scale_dtype="fp16",
-    export_args: dict = {"format": None, "inplace": True},
     **kwargs,
 ):
     """Run autoround weight-only quantization.
@@ -726,8 +725,8 @@ def autoround_quantize(
         }
         enable_full_range (bool): Whether to enable full range quantization (default is False).
         bs (int): Batch size for training (default is 8).
-        amp (bool): Whether to use automatic mixed precision (default is True).
-        device: The device to be used for tuning (default is "cuda:0").
+        amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
+        device: The device to be used for tuning (default is None). Automatically detect and set.
         lr_scheduler: The learning rate scheduler to be used.
         dataloader: The dataloader for input data (to be supported in future).
         dataset_name (str): The default dataset name (default is "NeelNanda/pile-10k").
@@ -747,8 +746,6 @@ def autoround_quantize(
         not_use_best_mse (bool): Whether to use mean squared error (default is False).
         dynamic_max_gap (int): The dynamic maximum gap (default is -1).
         data_type (str): The data type to be used (default is "int").
-        export_args (dict): The arguments for exporting compressed model, default is {"format": None, "inplace": True}.
-            Supported format: "itrex", "auto_gptq".
         **kwargs: Additional keyword arguments.

     Returns:
@@ -790,11 +787,4 @@ def autoround_quantize(
         **kwargs,
     )
     qdq_model, weight_config = rounder.quantize()
-    if export_args["format"] is not None:
-        output_dir = export_args.get("output_dir", None)
-        format = export_args["format"]
-        inplace = export_args.get("inplace", True)
-        use_triton = export_args.get("use_triton", False)
-        model = rounder.save_quantized(output_dir=output_dir, format=format, inplace=inplace, use_triton=use_triton)
-        return model, weight_config
     return qdq_model, weight_config
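
The new device=None default hands device selection over to auto-round itself; the actual detection logic lives in that library, but a hypothetical sketch of the behavior the docstring describes ("automatically detect and set") might look like this (helper names are illustrative, not auto-round's API):

import torch

def _detect_device(device=None):
    # Illustrative only: honor an explicit choice, otherwise probe for CUDA.
    if device is not None:
        return device
    return "cuda:0" if torch.cuda.is_available() else "cpu"

def _detect_amp(device):
    # Illustrative only: mixed precision generally only pays off on CUDA devices.
    return str(device).startswith("cuda")
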
4 changes: 2 additions & 2 deletions neural_compressor/model/torch_model.py
@@ -559,9 +559,9 @@ def export_compressed_model(
                 new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm)
                 set_module(self.model, k, new_module)
         elif autoround_config:
-            from auto_round.export.export_to_itrex import compress_model  # pylint: disable=E0401
+            from auto_round.export.export_to_itrex.export import _pack_model  # pylint: disable=E0401

-            self.model = compress_model(
+            self.model = _pack_model(
                 self.model,
                 weight_config=autoround_config,
                 enable_full_range=enable_full_range,

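With export_args removed from autoround_quantize, packing into an ITREX-format compressed model now goes through the model wrapper above instead. A hedged sketch of the 2.x flow (method name taken from this diff; the no-argument call assumes default parameters):

q_model = quantization.fit(fp32_model, conf, calib_dataloader=dataloader)
# For AutoRound results, export_compressed_model() packs the INT weights via auto_round's _pack_model.
compressed_model = q_model.export_compressed_model()
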
55 changes: 0 additions & 55 deletions test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py
@@ -778,11 +778,9 @@ def test_AutoRound_quant(self):
             recipes={
                 "autoround_args": {
                     "n_samples": 20,
-                    "amp": False,
                     "seq_len": 10,
                     "iters": 10,
                     "scale_dtype": "fp32",
-                    "device": "cpu",
                 },
             },
         )
@@ -809,59 +807,6 @@ def test_AutoRound_quant(self):
         self.assertTrue(isinstance(q_model.model.transformer.h[0].attn.k_proj, WeightOnlyLinear))
         self.assertTrue(isinstance(export_model.transformer.h[0].attn.k_proj, WeightOnlyLinear))

-        fp32_model = copy.deepcopy(self.gptj)
-
-        conf = PostTrainingQuantConfig(
-            approach="weight_only",
-            op_type_dict={
-                ".*": {  # re.match
-                    "weight": {
-                        "dtype": "int",
-                        "bits": 4,
-                        "group_size": 32,  # -1 (per-channel)
-                        "scheme": "sym",
-                        "algorithm": "AUTOROUND",
-                    },
-                },
-            },
-            op_name_dict={
-                ".*lm_head": {  # re.match
-                    "weight": {"dtype": "fp32"},
-                },
-            },
-            recipes={
-                "autoround_args": {
-                    "n_samples": 20,
-                    "amp": False,
-                    "seq_len": 10,
-                    "iters": 10,
-                    "scale_dtype": "fp32",
-                    "device": "cpu",
-                    "export_args": {"format": "itrex", "inplace": False},
-                },
-            },
-        )
-        """All export arguments.
-
-        "export_args": {
-            "format": "itrex",  # "iterx", "auto_gptq", default is None
-            "output_dir": None,  # saved path
-            "inplace": False,
-            "use_triton": False,
-        }
-        """
-        input = torch.ones([1, 512], dtype=torch.long)
-        fp32_model = copy.deepcopy(self.gptj)
-        out1 = fp32_model(input)
-        export_model = quantization.fit(
-            fp32_model,
-            conf,
-            calib_dataloader=dataloader,
-        )
-        out2 = export_model.model(input)
-        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01))
-        self.assertTrue(isinstance(export_model.model.transformer.h[0].attn.k_proj, WeightOnlyLinear))
-

 if __name__ == "__main__":
     unittest.main()
30 changes: 0 additions & 30 deletions test/quantization/test_weight_only_quantization.py
@@ -278,8 +278,6 @@ def test_autoround_int_quant(self):
             model=model,
             tokenizer=self.tokenizer,
             n_samples=20,
-            device=device,
-            amp=False,
             seqlen=10,
             iters=10,
             scale_dtype="fp32",
@@ -292,34 +290,6 @@ def test_autoround_int_quant(self):
         self.assertFalse(torch.all(out1[0] == out2[0]))
         self.assertTrue(torch.all(out2[0] == out3[0]))

-    def test_autoround_export(self):
-        model = copy.deepcopy(self.gptj)
-        device = "cpu"
-        model = model
-        out1 = model(self.lm_input)
-        export_model, weight_config1 = autoround_quantize(
-            model=model,
-            tokenizer=self.tokenizer,
-            n_samples=20,
-            device=device,
-            amp=False,
-            seqlen=10,
-            iters=10,
-            scale_dtype="fp32",
-            export_args={"format": "itrex", "inplace": True},
-        )
-        export_model = export_model
-        model = model
-        out2 = model(self.lm_input)
-        out3 = export_model(self.lm_input)
-        self.assertTrue(torch.all(torch.isclose(out1[0], out2[0], atol=1e-1)))
-        self.assertFalse(torch.all(out1[0] == out2[0]))
-        self.assertTrue(torch.all(out2[0] == out3[0]))
-
-        from auto_round.export.export_to_itrex.model_wrapper import WeightOnlyLinear
-
-        self.assertTrue(isinstance(export_model.transformer.h[0].attn.k_proj, WeightOnlyLinear))
-

 if __name__ == "__main__":
     unittest.main()
2 changes: 1 addition & 1 deletion test/requirements.txt
@@ -1,7 +1,7 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
 accelerate==0.21.0
 dynast==1.6.0rc1
-git+https://github.com/intel/auto-round.git@b65830f3f6cb32d92a5c8ba5f80ace12d517357b
+git+https://github.com/intel/auto-round.git@6815f8b66be456ecbef2d0beb33dbc4efeefdc04
 horovod
 intel-extension-for-pytorch
 intel-tensorflow>=2.12.0