Modify qlinear_cuda for tracing the GPTQ model (#367)
Changes:
-- The torch.bitwise_and calls are changed because, when tracing this
   model, the current usage (with the out= argument) results in the
   in-place variant of the op, which breaks the downstream lowering
   pipeline of the traced model via Torch-MLIR and IREE-SHARK. The calls
   are therefore rewritten to use the functional (out-of-place) variant;
   see the first sketch below.

-- The torch.matmul call in the forward function is changed because it
   currently assumes that the weights are always fp16, so running the
   model with float32 weights raises an error. The change therefore
   casts the LHS of the matmul to the same dtype as the RHS; see the
   second sketch below.

Neither of the above changes affects the model's behavior in any way.
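A minimal sketch of the first change (not part of the commit; the tensor
values and bit width are illustrative assumptions):

    import torch

    bits = 4
    zeros = torch.tensor([[0x73, 0x21]], dtype=torch.int16)  # placeholder packed values

    # Before: the out= form is recorded as an in-place variant of the op
    # during tracing, which the Torch-MLIR / IREE-SHARK lowering pipeline
    # cannot handle.
    # torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)

    # After: the functional (out-of-place) form traces as a plain bitwise_and.
    zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)
    print(zeros)  # tensor([[3, 1]], dtype=torch.int16)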

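A minimal sketch of the second change (again illustrative: the shapes and
dtypes are assumptions, with float32 activations and dequantized float32
weights):

    import torch

    x = torch.randn(2, 8)                              # float32 activations
    weights = torch.randn(8, 4, dtype=torch.float32)   # dequantized float32 weights

    # Before: torch.matmul(x.half(), weights) assumes fp16 weights and raises
    # a dtype-mismatch error when the weights are float32.
    # After: cast the LHS to the RHS dtype, then return in the input's dtype.
    out = torch.matmul(x.to(weights.dtype), weights)
    out = out.to(x.dtype)
    print(out.shape, out.dtype)  # torch.Size([2, 4]) torch.float32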
Signed-off-by: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 authored Oct 20, 2023
1 parent 51c043c commit e4b2493
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda.py
@@ -219,7 +219,7 @@ def forward(self, x: torch.Tensor):
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0)
).to(torch.int16 if self.bits == 8 else torch.int8)
- torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
+ zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)

zeros = zeros + 1
zeros = zeros.reshape(self.scales.shape)
@@ -228,7 +228,7 @@ def forward(self, x: torch.Tensor):
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1)
).to(torch.int16 if self.bits == 8 else torch.int8)
- torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
+ weight = torch.bitwise_and(weight, (2 ** self.bits) - 1)
elif self.bits == 3:
zeros = self.qzeros.reshape(
self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1
@@ -267,10 +267,10 @@ def forward(self, x: torch.Tensor):
g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
weights = torch.cat(weights,dim=1)
- out = torch.matmul(x.half(), weights)
+ out = torch.matmul(x.to(weights.dtype), weights)
out = out.half().reshape(out_shape)
out = out + self.bias if self.bias is not None else out
- return out
+ return out.to(x.dtype)


__all__ = ["QuantLinear"]
8 changes: 4 additions & 4 deletions auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py
@@ -229,7 +229,7 @@ def forward(self, x):

if self.bits in [2,4,8]:
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
- torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
+ zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)

zeros = zeros + 1
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
@@ -238,7 +238,7 @@ def forward(self, x):
scales = scales.reshape(-1, 1, scales.shape[-1])

weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
- torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight)
+ weight = torch.bitwise_and(weight,(2 ** self.bits) - 1)
weight = weight.reshape(-1, self.group_size, weight.shape[2])
elif self.bits == 3:
zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12)
@@ -266,10 +266,10 @@ def forward(self, x):
weight = (scales * (weight - zeros))
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

- out = torch.matmul(x.half(), weight)
+ out = torch.matmul(x.to(weight.dtype), weight)
out = out.half().reshape(out_shape)
out = out + self.bias if self.bias is not None else out
- return out
+ return out.to(x.dtype)


__all__ = ["QuantLinear"]
