-
Notifications
You must be signed in to change notification settings - Fork 9
/
text_birnn.py
54 lines (43 loc) · 2.12 KB
/
text_birnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
from Config import config
def textbirnn(input_x, dropout_keep_prob, dataset, reuse=False,
              embedding_size=300, hidden_size=128):
    """
    Build the graph of a bi-directional LSTM text classifier.

    The graph is: embedding lookup -> one bi-directional LSTM layer ->
    mean over the time axis -> dropout -> fully-connected layer.

    Args:
        input_x: Tensor of token ids fed to ``tf.nn.embedding_lookup``
            (presumably int ids shaped [batch, sequence_length] -- confirm
            against the caller).
        dropout_keep_prob: Passed as the second positional argument of
            ``tf.nn.dropout``.
        dataset: Key used to look up ``config.num_classes`` and
            ``config.num_words``.
        reuse: Forwarded to every ``tf.variable_scope`` so that a second call
            can share the variables created by the first.
        embedding_size: Width of each embedding vector. Defaults to 300, the
            value that was previously hard-coded.
        hidden_size: Number of units in each LSTM direction. Defaults to 128,
            the value that was previously hard-coded. The output layer's
            input width is derived from it (``2 * hidden_size``), so the two
            can no longer drift apart.

    Returns:
        A tuple ``(embeddings, embedded_chars, predictions, scores)``:
        the embedding variable, the looked-up embeddings for ``input_x``,
        the int32 argmax of ``scores`` along axis 1, and the unnormalized
        class scores.
    """
    num_classes = config.num_classes[dataset]
    vocab_size = config.num_words[dataset]

    # Embedding layer
    with tf.variable_scope("embedding", reuse=reuse):
        # The table has vocab_size + 1 rows (presumably one extra id for
        # padding/OOV -- confirm against the data pipeline).
        embeddings = tf.get_variable(
            "W",
            initializer=tf.random_uniform(
                [vocab_size + 1, embedding_size], -1.0, 1.0),
            trainable=True)
        # [None, sequence_length, embedding_size]
        embedded_chars = tf.nn.embedding_lookup(
            embeddings, input_x, name="embedded_chars")

    # Bi-directional LSTM layer
    with tf.variable_scope("bilstm", reuse=reuse):
        fw_cell = tf.nn.rnn_cell.BasicLSTMCell(
            hidden_size, state_is_tuple=True)  # forward direction cell
        bw_cell = tf.nn.rnn_cell.BasicLSTMCell(
            hidden_size, state_is_tuple=True)  # backward direction cell
        # Only the per-timestep outputs are used; the final states are
        # discarded.
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            fw_cell, bw_cell, embedded_chars, dtype=tf.float32)
        # Join forward and backward outputs on the feature axis, then average
        # over the time axis to get one vector per example.
        # NOTE(review): no sequence_length is passed to the RNN, so every
        # timestep contributes to the mean; if input_x is padded, the padded
        # positions are averaged in as well -- confirm this is intended.
        outputs = tf.concat(outputs, axis=2)
        output = tf.reduce_mean(outputs, axis=1)

    # Add dropout
    with tf.variable_scope("dropout", reuse=reuse):
        rnn_drop = tf.nn.dropout(output, dropout_keep_prob)

    # Final (unnormalized) scores and predictions
    with tf.variable_scope("output", reuse=reuse):
        W = tf.get_variable(
            "W",
            shape=[2 * hidden_size, num_classes],
            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable(
            "b", initializer=tf.constant(0.1, shape=[num_classes]))
        scores = tf.nn.xw_plus_b(rnn_drop, W, b, name="scores")
        predictions = tf.argmax(
            scores, 1, name="predictions", output_type=tf.int32)

    return embeddings, embedded_chars, predictions, scores