From 56304c45691d0c3b54e723a35b485c730cd167f3 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Fri, 15 Mar 2024 20:21:18 +0800
Subject: [PATCH 1/2] [transformer] remove pe to device

---
 wenet/transformer/embedding.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index 17d8810ff..ba8deccd3 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -45,15 +45,16 @@ def __init__(self,
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.max_len = max_len
 
-        self.pe = torch.zeros(self.max_len, self.d_model)
+        pe = torch.zeros(self.max_len, self.d_model)
         position = torch.arange(0, self.max_len,
                                 dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32) *
             -(math.log(10000.0) / self.d_model))
-        self.pe[:, 0::2] = torch.sin(position * div_term)
-        self.pe[:, 1::2] = torch.cos(position * div_term)
-        self.pe = self.pe.unsqueeze(0)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = self.pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
 
     def forward(self,
                 x: torch.Tensor,
@@ -70,7 +71,6 @@ def forward(self,
             torch.Tensor: for compatibility to RelPositionalEncoding
         """
-        self.pe = self.pe.to(x.device)
         pos_emb = self.position_encoding(offset, x.size(1), False)
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
@@ -140,7 +140,6 @@ def forward(self,
             torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Positional embedding tensor (1, time, `*`).
         """
-        self.pe = self.pe.to(x.device)
         x = x * self.xscale
         pos_emb = self.position_encoding(offset, x.size(1), False)
         return self.dropout(x), self.dropout(pos_emb)

From 317b67baeebe236ab196cfa3a3c69a67c92b2cb7 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Fri, 15 Mar 2024 20:23:35 +0800
Subject: [PATCH 2/2] fix self.pe

---
 wenet/transformer/embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index ba8deccd3..1744ed839 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -53,7 +53,7 @@ def __init__(self,
             -(math.log(10000.0) / self.d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = self.pe.unsqueeze(0)
+        pe = pe.unsqueeze(0)
         self.register_buffer("pe", pe)
 
     def forward(self,
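
Taken together, the two commits replace the plain `self.pe` attribute (which had to be moved with `self.pe = self.pe.to(x.device)` on every forward call) with a registered buffer that PyTorch moves to the right device and serializes along with the module. PATCH 1/2 still reads `self.pe.unsqueeze(0)` before the buffer exists; PATCH 2/2 fixes that to `pe.unsqueeze(0)`. Below is a minimal, self-contained sketch of the resulting pattern, separate from the patch itself: the class name and the simplified forward (no xscale, no dropout, no streaming offset) are illustrative only, not wenet's actual API.

```python
import math

import torch


class SinusoidalPositionalEncoding(torch.nn.Module):
    """Illustrative sketch of the buffer-based positional encoding.

    Because ``pe`` is registered as a buffer, ``module.to(device)`` /
    ``module.cuda()`` move it together with the parameters, so forward
    no longer needs ``self.pe = self.pe.to(x.device)``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Assumes an even d_model so the sin/cos slices line up.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Registered buffer: part of state_dict, moved by .to()/.cuda(),
        # but not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.pe is already on x's device if the module was moved there.
        return x + self.pe[:, :x.size(1)]


if __name__ == "__main__":
    enc = SinusoidalPositionalEncoding(d_model=256)
    if torch.cuda.is_available():
        enc = enc.cuda()  # the "pe" buffer follows the module
    x = torch.zeros(2, 10, 256, device=enc.pe.device)
    print(enc(x).shape)  # torch.Size([2, 10, 256])
```

One side effect worth noting: with the default (persistent) `register_buffer`, `pe` appears in the module's `state_dict`, whereas the old plain attribute did not, which can matter when loading older checkpoints with `strict=True`.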