forked from lingyongyan/Neural-Machine-Translation
-
Notifications
You must be signed in to change notification settings - Fork 10
/
encoder.py
31 lines (25 loc) · 1.07 KB
/
encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
class EncoderRNN(nn.Module):
    """Recurrent neural network that encodes a given input sequence.

    Token ids are embedded, passed through dropout, and fed to a
    (possibly stacked) GRU. Sequences are processed one at a time
    (batch size 1, sequence-first layout).

    Args:
        src_vocab_size: number of rows in the source embedding table.
        embedding_size: dimensionality of each token embedding; this is
            also the GRU input size.
        hidden_size: size of the GRU hidden state.
        n_layers: number of stacked GRU layers.
        dropout: dropout probability applied to the embedded tokens.
    """

    def __init__(self, src_vocab_size, embedding_size, hidden_size, n_layers=1, dropout=0.1):
        super(EncoderRNN, self).__init__()
        self.src_vocab_size = src_vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        # Embeddings are embedding_size wide and the GRU consumes exactly that
        # width. NOTE: when embedding_size != hidden_size, parameter shapes
        # differ from a layout that sized both by hidden_size, so checkpoints
        # saved with that layout will not load into this module.
        self.embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.dropout = nn.Dropout(dropout)
        self.rnn = nn.GRU(embedding_size, hidden_size, n_layers)

    def forward(self, inputs, hidden_state=None):
        """Encode a single token sequence.

        Args:
            inputs: integer token ids, ``len`` elements in total (e.g. shape
                ``[len]``); flattened to ``[len, 1]``.
            hidden_state: initial GRU state ``[n_layers, 1, hidden_size]``,
                e.g. from ``init_hidden``. ``None`` lets the GRU start
                from zeros.

        Returns:
            ``(output, hidden_state)``:
                - ``output`` is ``[len, 1, hidden_size]`` (top-layer outputs
                  for every step).
                - ``hidden_state`` is ``[n_layers, 1, hidden_size]``.
        """
        # nn.GRU defaults to batch_first=False, so the layout is
        # [seq_len, batch=1]. reshape also accepts non-contiguous id tensors.
        inputs = inputs.reshape(-1, 1)
        embedded = self.embedding(inputs)  # [len, 1, embedding_size]
        embedded = self.dropout(embedded)
        output, hidden_state = self.rnn(embedded, hidden_state)
        return output, hidden_state

    def init_hidden(self, device):
        """Return an all-zero hidden state ``[n_layers, 1, hidden_size]`` on ``device``."""
        # Allocate directly on the target device instead of creating on CPU and copying.
        return torch.zeros(self.n_layers, 1, self.hidden_size, device=device)