diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index 17d8810ff..1744ed839 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -45,15 +45,16 @@ def __init__(self,
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.max_len = max_len
 
-        self.pe = torch.zeros(self.max_len, self.d_model)
+        pe = torch.zeros(self.max_len, self.d_model)
         position = torch.arange(0, self.max_len,
                                 dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32) *
             -(math.log(10000.0) / self.d_model))
-        self.pe[:, 0::2] = torch.sin(position * div_term)
-        self.pe[:, 1::2] = torch.cos(position * div_term)
-        self.pe = self.pe.unsqueeze(0)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
 
     def forward(self,
                 x: torch.Tensor,
@@ -70,7 +71,6 @@ def forward(self,
             torch.Tensor: for compatibility to RelPositionalEncoding
         """
-        self.pe = self.pe.to(x.device)
         pos_emb = self.position_encoding(offset, x.size(1), False)
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
@@ -140,7 +140,6 @@ def forward(self,
             torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Positional embedding tensor (1, time, `*`).
         """
-        self.pe = self.pe.to(x.device)
        x = x * self.xscale
         pos_emb = self.position_encoding(offset, x.size(1), False)
         return self.dropout(x), self.dropout(pos_emb)
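
The patch above swaps the plain `self.pe` tensor attribute for a buffer registered via `register_buffer`, so the positional-encoding table is moved together with the module by `Module.to()` / `.cuda()` and the per-forward `self.pe = self.pe.to(x.device)` copies can be dropped. A minimal sketch of that behavior (the `Toy` module and its sizes are illustrative, not part of the patch):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: stays on its original device unless moved by hand.
        self.plain = torch.zeros(1, 8, 4)
        # Registered buffer: follows Module.to()/.cuda() and is saved in state_dict().
        self.register_buffer("pe", torch.zeros(1, 8, 4))

m = Toy()
print("pe" in m.state_dict())   # True: buffers are checkpointed by default
if torch.cuda.is_available():
    m = m.cuda()
    print(m.plain.device)       # cpu    -> would still need a manual .to(x.device)
    print(m.pe.device)          # cuda:0 -> moved automatically with the module

Note that a registered buffer is included in `state_dict()` by default; if that ever matters for checkpoint compatibility, `register_buffer("pe", pe, persistent=False)` would keep it out.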