From faae87c1e022b2f4a8c7b77e0dec97d1c0c730e0 Mon Sep 17 00:00:00 2001
From: Ivan Kobzarev
Date: Tue, 6 Jun 2023 08:46:22 -0700
Subject: [PATCH] Remove contiguous() in split_embedding_weights_with_scale_bias

Summary:
Calling `tensor.contiguous()` on a non-contiguous tensor creates a new tensor; changing that copy will not change the original `tensor`.

To use the results of `split_embedding_weights_with_scale_bias(split_scale_bias_mode=2)` as tensors in a state_dict, writes through those tensors must reach the original TBE weights. For that, the copy introduced by `contiguous()` has to be removed.

Differential Revision: D46483112

Privacy Context Container: L1138451

fbshipit-source-id: dffaad9c7e92aaaf7761958c0d190a7ed21ee00e
---
 .../split_table_batched_embeddings_ops_inference.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index f91a358cd3..9c6d1b5b19 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -1270,16 +1270,14 @@ def split_embedding_weights_with_scale_bias(
                         splits.append(
                             (
                                 weights_shifts[:, self.scale_bias_size_in_bytes :],
-                                weights_shifts[:, : self.scale_bias_size_in_bytes // 2]
-                                .contiguous()
-                                .view(torch.float16),
+                                weights_shifts[
+                                    :, : self.scale_bias_size_in_bytes // 2
+                                ].view(torch.float16),
                                 weights_shifts[
                                     :,
                                     self.scale_bias_size_in_bytes
                                     // 2 : self.scale_bias_size_in_bytes,
-                                ]
-                                .contiguous()
-                                .view(torch.float16),
+                                ].view(torch.float16),
                             )
                         )
                     elif (
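
For illustration, a minimal standalone sketch of the aliasing behavior the summary describes. The 4x8 uint8 buffer and variable names below are hypothetical stand-ins for a real packed TBE weight tensor; the byte layout mirrors split_scale_bias_mode=2. The key point is that `view(torch.float16)` is legal on the non-contiguous column slice (its last dimension has stride 1 and even length) and shares storage with the source, whereas `contiguous()` materializes a copy:

    import torch

    # Hypothetical stand-in for a packed TBE weight buffer: per row,
    # bytes [0:2] hold an fp16 scale, [2:4] an fp16 bias, [4:] the weights.
    weights = torch.zeros(4, 8, dtype=torch.uint8)

    # With contiguous(): the column slice is non-contiguous, so a copy is
    # materialized and writes through the fp16 view never reach `weights`.
    scale_copy = weights[:, :2].contiguous().view(torch.float16)
    scale_copy.fill_(1.0)
    print(int(weights.sum()))  # 0 -- the original buffer is untouched

    # Without contiguous(): the fp16 view shares storage with `weights`,
    # so in-place writes land in the original buffer, which is what the
    # state_dict use case needs.
    scale_view = weights[:, :2].view(torch.float16)
    scale_view.fill_(1.0)
    print(int(weights.sum()))  # nonzero -- the original buffer is updated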