diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 0cb6d6d938d..8f46b778ec5 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -1105,7 +1105,11 @@ def __iter__(self): if not args: yield kwargs elif not kwargs: - yield args + # case: tensor + if len(args) == 1: + yield args[0] + else: + yield args else: yield args, kwargs diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 27d30753cdb..f4eb777f5d9 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -740,7 +740,7 @@ def __init__( minmax_lr: float = None, low_gpu_mem_usage: bool = True, iters: int = 200, - seqlen: int = 2048, + seqlen: int = 512, n_samples: int = 512, sampler: str = "rand", seed: int = 42, @@ -1507,8 +1507,7 @@ def get_woq_tuning_config() -> list: the list of WOQ quant config. """ RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32) + AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32) GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32) - GPTQ_G32ASYM_DISABLE_LAST_LINEAR = GPTQConfig(use_sym=False).set_local("*.lm_head", GPTQConfig(dtype="fp32")) - GPTQ_G128ASYM = GPTQConfig(group_size=128, use_sym=False) AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32) - return [RTN_G32ASYM, GPTQ_G32ASYM, GPTQ_G32ASYM_DISABLE_LAST_LINEAR, GPTQ_G128ASYM, AWQ_G32ASYM] + return [RTN_G32ASYM, AUTO_ROUND_CONFIG, GPTQ_G32ASYM, AWQ_G32ASYM] diff --git a/test/3x/torch/quantization/weight_only/test_woq_utils.py b/test/3x/torch/quantization/weight_only/test_woq_utils.py index c31d94b823d..3bee40696c8 100644 --- a/test/3x/torch/quantization/weight_only/test_woq_utils.py +++ b/test/3x/torch/quantization/weight_only/test_woq_utils.py @@ -169,7 +169,16 @@ def test_captured_dataloader_iteration(self): result = list(dataloader) - assert result == [(1,), (2,), (3,)] + assert result == [1, 2, 3] + + # Test case when kwargs is empty + args_list = [(1, 2), (2, 3), (3, 4)] + kwargs_list = [{}, {}, {}] + dataloader = CapturedDataloader(args_list, kwargs_list) + + result = list(dataloader) + + assert result == [(1, 2), (2, 3), (3, 4)] # Test case when both args and kwargs are present args_list = [(1,), (2,), (3,)]