Remove contiguous() in split_embedding_weights_with_scale_bias
Summary:
Calling `tensor.contiguous()` on a non-contiguous tensor creates a new tensor; changing that copy will not change the original `tensor`.

To use the results of `split_embedding_weights_with_scale_bias(split_scale_bias_mode=2)` as tensors in a state_dict, we need to be able to change the original TBE weights through those tensors.

For that, the copy made by `contiguous()` has to be removed.
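A minimal sketch of the copy-vs-view distinction this change relies on; the `(2, 8)` uint8 buffer and the 4-byte scale/bias region below are illustrative assumptions, not the actual TBE layout:

    import torch

    # Hypothetical stand-in for a quantized TBE weight buffer: 2 rows of 8 raw
    # bytes, where (by assumption here) the first 4 bytes of each row hold two
    # float16 scale/bias values.
    weights = torch.zeros(2, 8, dtype=torch.uint8)

    # With .contiguous(): the non-contiguous column slice is copied first, so
    # writes through the float16 view land in the copy, not in `weights`.
    copied = weights[:, :4].contiguous().view(torch.float16)
    copied.fill_(1.0)
    assert not weights.any()  # original buffer is unchanged

    # Without .contiguous(): view(dtype) is allowed because the slice's last
    # dimension is contiguous, and the result aliases the original storage.
    aliased = weights[:, :4].view(torch.float16)
    aliased.fill_(1.0)
    assert weights[:, :4].any()  # the write reached the original buffer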

Differential Revision: D46483112

Privacy Context Container: L1138451

fbshipit-source-id: dffaad9c7e92aaaf7761958c0d190a7ed21ee00e
Ivan Kobzarev authored and facebook-github-bot committed Jun 6, 2023
1 parent e9d7e3e commit faae87c
Showing 1 changed file with 4 additions and 6 deletions.
@@ -1270,16 +1270,14 @@ def split_embedding_weights_with_scale_bias(
                 splits.append(
                     (
                         weights_shifts[:, self.scale_bias_size_in_bytes :],
-                        weights_shifts[:, : self.scale_bias_size_in_bytes // 2]
-                        .contiguous()
-                        .view(torch.float16),
+                        weights_shifts[
+                            :, : self.scale_bias_size_in_bytes // 2
+                        ].view(torch.float16),
                         weights_shifts[
                             :,
                             self.scale_bias_size_in_bytes
                             // 2 : self.scale_bias_size_in_bytes,
-                        ]
-                        .contiguous()
-                        .view(torch.float16),
+                        ].view(torch.float16),
                     )
                 )
             elif (
