diff --git a/neural_compressor/adaptor/torch_utils/teq.py b/neural_compressor/adaptor/torch_utils/teq.py index 716dcf236b2..4c1298c641e 100644 --- a/neural_compressor/adaptor/torch_utils/teq.py +++ b/neural_compressor/adaptor/torch_utils/teq.py @@ -256,7 +256,9 @@ def train( while global_steps <= train_steps: for inputs in dataloader: - if isinstance(inputs, dict): + if isinstance(inputs, torch.Tensor): + input_id = inputs + elif isinstance(inputs, dict): input_id = inputs["input_ids"] else: input_id = inputs[0]