From c006921e9e9977bc107b05676266b581091688a2 Mon Sep 17 00:00:00 2001
From: Jakub Bartczuk <3647577+lambdaofgod@users.noreply.github.com>
Date: Mon, 18 Dec 2023 16:04:52 +0100
Subject: [PATCH] added rnn.pack_padded_sequence argument cpu conversion
 (#1420)

---
 sentence_transformers/models/LSTM.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sentence_transformers/models/LSTM.py b/sentence_transformers/models/LSTM.py
index a79cfacc5..31782d540 100644
--- a/sentence_transformers/models/LSTM.py
+++ b/sentence_transformers/models/LSTM.py
@@ -29,7 +29,7 @@ def forward(self, features):
         token_embeddings = features['token_embeddings']
         sentence_lengths = torch.clamp(features['sentence_lengths'], min=1)
 
-        packed = nn.utils.rnn.pack_padded_sequence(token_embeddings, sentence_lengths, batch_first=True, enforce_sorted=False)
+        packed = nn.utils.rnn.pack_padded_sequence(token_embeddings, sentence_lengths.cpu(), batch_first=True, enforce_sorted=False)
         packed = self.encoder(packed)
         unpack = nn.utils.rnn.pad_packed_sequence(packed[0], batch_first=True)[0]
         features.update({'token_embeddings': unpack})
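
Context for the change: recent PyTorch releases (1.7 and later) require the lengths argument of nn.utils.rnn.pack_padded_sequence to be a CPU tensor (or a plain list of ints); passing a CUDA tensor raises a RuntimeError, even though the input embeddings and the LSTM itself may live on the GPU. The snippet below is a minimal, self-contained sketch of the same pattern outside the sentence_transformers LSTM module; the tensor sizes and the bidirectional LSTM configuration are illustrative assumptions, not values from the library.

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Illustrative sizes, not taken from sentence_transformers.
    batch, max_len, emb_dim, hidden_dim = 4, 10, 32, 64
    token_embeddings = torch.randn(batch, max_len, emb_dim, device=device)
    sentence_lengths = torch.tensor([10, 7, 5, 1], device=device)
    sentence_lengths = torch.clamp(sentence_lengths, min=1)

    encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True, bidirectional=True).to(device)

    # Without .cpu(), pack_padded_sequence raises a RuntimeError on PyTorch >= 1.7
    # when sentence_lengths is a CUDA tensor; the lengths must be on the CPU.
    packed = nn.utils.rnn.pack_padded_sequence(
        token_embeddings, sentence_lengths.cpu(), batch_first=True, enforce_sorted=False
    )
    output, _ = encoder(packed)
    unpacked = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
    print(unpacked.shape)  # (batch, max_len, 2 * hidden_dim) for a bidirectional LSTM

Only the lengths are moved to the CPU; the packed data stays on whatever device token_embeddings is on, so the encoder itself still runs on the GPU when one is available.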