
Modify qlinear_cuda for tracing the GPTQ model #367

Merged · 1 commit into AutoGPTQ:main on Oct 20, 2023

Conversation

@vivekkhandelwal1 (Contributor):

Changes:
-- The torch.bitwise_and usage is changed because, during tracing, the current
   call produces an in-place variant of the op, which breaks the downstream
   lowering pipeline of the traced model via Torch-MLIR and IREE-SHARK. The
   call is therefore rewritten so it no longer produces an in-place variant.

-- The torch.matmul call in the forward function is changed because it
   currently assumes the weights are always fp16, so running the model with
   float32 weights raises an error. The change casts the LHS of the matmul to
   the same dtype as the RHS.

Both the above changes don't affect the model in any way (a minimal sketch of both patterns follows below).

Signed-Off By: Vivek Khandelwal [email protected]
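
A minimal sketch (not the actual qlinear_cuda code) of the two patterns described above; all tensor names and shapes are illustrative.

```python
import torch

packed = torch.randint(0, 2**31 - 1, (4, 8), dtype=torch.int32)
mask = torch.tensor(0xF, dtype=torch.int32)

# Out-of-place bitwise_and: the PR reports that the previous call pattern
# traced to an in-place variant that Torch-MLIR / IREE-SHARK could not lower.
unpacked = torch.bitwise_and(packed >> 4, mask)

# Cast the LHS to the RHS dtype instead of hard-coding .half(), so the same
# matmul works for both fp16 and fp32 weights.
weights = unpacked.to(torch.float32)   # illustrative "dequantized" weights
x = torch.randn(2, 4)                  # fp32 activations
out = torch.matmul(x.to(weights.dtype), weights)
```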

@vivekkhandelwal1 (Contributor, Author):

@PanQiWei, can you please take a look at this PR?

@vivekkhandelwal1 (Contributor, Author):

@qwopqwop200 @TheBloke @fxmarty @PanQiWei Can you please review this PR? I'm blocked on it right now and would be really grateful if we could get it merged soon.

@fxmarty (Collaborator) commented Oct 20, 2023:

Apologies for the delay.

> Both the above changes don't affect the model in any way.

Well, it kind of does; I guess there's an additional memory allocation instead of doing the operation in-place? That doesn't sound like a big deal, though.

I am wondering, though, what the point of using AutoGPTQ is if self.autogptq_cuda_available is False? It feels like the Python implementation must be very slow. Does Torch-MLIR need to be able to lower every branch of the control flow?

@@ -267,10 +267,10 @@ def forward(self, x: torch.Tensor):
     g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
     weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
 weights = torch.cat(weights,dim=1)
-out = torch.matmul(x.half(), weights)
+out = torch.matmul(x.to(weights.dtype), weights)
 out = out.half().reshape(out_shape)

Collaborator:

Why is there still a .half() here?

Contributor (Author):

I'm not sure about this. I think this was left by mistake. Should I remove this?

Collaborator:

I don't mind leaving it if this .half() is not a blocker for you, but I was wondering, given that you replaced some of the other .half() calls to remove the fp16 assumption.

@vivekkhandelwal1 (Contributor, Author):

> Apologies for the delay.
>
> > Both the above changes don't affect the model in any way.
>
> Well, it kind of does; I guess there's an additional memory allocation instead of doing the operation in-place? That doesn't sound like a big deal, though.
>
> I am wondering, though, what the point of using AutoGPTQ is if self.autogptq_cuda_available is False? It feels like the Python implementation must be very slow. Does Torch-MLIR need to be able to lower every branch of the control flow?

So, the model would be run on a GPU, but only after being lowered through Torch-MLIR and compiled via IREE. The changes are needed because Torch-MLIR doesn't support tensors on a CUDA device, so the model is traced on the CPU, lowered via Torch-MLIR, and then, after compilation, run on GPUs through the CUDA, Vulkan, and ROCm backends.
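
A rough sketch of the kind of flow described here (not from this PR): it assumes the torch_mlir.compile entry point shipped in Torch-MLIR releases around this time and uses a stand-in nn.Module instead of the actual GPTQ model; the IREE compile/run step is only indicated in comments.

```python
# Illustrative only: a tiny stand-in module traced on CPU and lowered via
# Torch-MLIR, mirroring the flow described above. The real pipeline would
# trace the CPU-resident GPTQ model and hand the resulting IR to IREE,
# which then targets the CUDA/Vulkan/ROCm backends.
import torch
import torch_mlir  # assumes the torch_mlir Python package is installed

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.matmul(x, x.t())

model = TinyModel().eval()         # keep the model on CPU for tracing
example_input = torch.randn(4, 8)  # CPU tensor; CUDA tensors are not supported here

# Lower to the linalg-on-tensors form, which IREE can consume.
module = torch_mlir.compile(model, example_input, output_type="linalg-on-tensors")

# From here the MLIR module would be compiled with IREE for a GPU backend
# (omitted; details depend on the IREE toolchain and target).
print(module)
```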

@fxmarty (Collaborator) commented Oct 20, 2023:

That's neat! Have you been able to run AutoGPTQ this way? Is it competitive on, say, CUDA compared to the homemade kernel?

@vivekkhandelwal1 (Contributor, Author):

> That's neat! Have you been able to run AutoGPTQ this way? Is it competitive on, say, CUDA compared to the homemade kernel?

Yeah! I have been able to run the falcon-180b-GPTQ on the CPU and the falcon-7b-GPTQ on the CPU and CUDA.

@vivekkhandelwal1 (Contributor, Author):

We have been blocked on this patch for weeks, and my changes in huggingface/transformers#26719 also depend on this PR getting merged. Please let me know if I need to make any further changes to get this in.

@fxmarty (Collaborator) left a review:

LGTM, thanks a lot!

Apologies for the delay and inconvenience; I am not maintaining the repo, so I'm not actively checking the PRs, I just happen to have rights.

@fxmarty merged commit e4b2493 into AutoGPTQ:main on Oct 20, 2023
@vivekkhandelwal1 (Contributor, Author):

> LGTM, thanks a lot!
>
> Apologies for the delay and inconvenience; I am not maintaining the repo, so I'm not actively checking the PRs, I just happen to have rights.

Thanks for merging this patch!

 out = out.half().reshape(out_shape)
 out = out + self.bias if self.bias is not None else out
-return out
+return out.to(x.dtype)

Collaborator:

This introduces a bug when using the CUDA kernel in fp32.

Contributor (Author):

What kind of bug? Can you please explain?

Collaborator:

When use_cuda_fp16=False, there is a cast x = x.to(torch.float32), which results in the output dtype being wrong with the change above. This is fixed in #382.
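
A minimal standalone sketch of the issue being described (not the actual kernel code): once x is reassigned to float32 inside forward, out.to(x.dtype) no longer refers to the caller's original dtype.

```python
# Illustrative reproduction of the dtype issue, using a dummy forward that
# mimics the use_cuda_fp16=False path; names and the computation are hypothetical.
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.float32)   # the use_cuda_fp16=False cast mentioned above
    out = x * 2.0             # stand-in for the actual quantized matmul
    # Because x was reassigned, x.dtype is now float32, so this cast is a
    # no-op and fp16 callers get a float32 tensor back.
    return out.to(x.dtype)

y = forward(torch.randn(4, dtype=torch.float16))
print(y.dtype)  # torch.float32, not the torch.float16 the caller passed in
```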

Contributor (Author):

Thanks @fxmarty!
