"fix style" #66

Merged
merged 2 commits into from
Jan 19, 2018
Merged
55 changes: 11 additions & 44 deletions fluid/stacked_dynamic_lstm.py
@@ -1,3 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 import argparse
 import cPickle
 import os
@@ -44,49 +48,11 @@ def parse_args():
         default=int(os.environ.get('CROP_SIZE', '1500')),
         help='The max sentence length of input. Since this model use plain RNN,'
         ' Gradient could be explored if sentence is too long')
-    parser.add_argument(
-        '--clean',
-        type=bool,
-        default=os.environ.get('CLEAN', 'False').lower() != 'false',
-        help='clean the cached pickle file.')
     args = parser.parse_args()
     return args


-try:
-    with open('word_dict.pkl', 'r') as f:
-        word_dict = cPickle.load(f)
-except:
-    word_dict = imdb.word_dict()
-    with open('word_dict.pkl', 'w') as f:
-        cPickle.dump(word_dict, f, cPickle.HIGHEST_PROTOCOL)
-
-
-def cache_reader(reader, clean):
-    print 'Reading data to memory'
-    fn = 'data.pkl'
-    if clean:
-        try:
-            os.remove(fn)
-        except:
-            pass
-    try:
-        with open(fn, 'r') as f:
-            items = cPickle.load(f)
-    except:
-        items = list(reader())
-        with open(fn, 'w') as f:
-            cPickle.dump(items, f, cPickle.HIGHEST_PROTOCOL)
-
-    print 'Done. data size %d' % len(items)
-
-    def __impl__():
-        offsets = range(len(items))
-        random.shuffle(offsets)
-        for i in offsets:
-            yield items[i]
-
-    return __impl__
+word_dict = imdb.word_dict()


 def crop_sentence(reader, crop_size):
@@ -132,7 +98,7 @@ def gate_common(
         x=gate_common(word, prev_hidden, lstm_size))
     output_gate = fluid.layers.sigmoid(
         x=gate_common(word, prev_hidden, lstm_size))
-    cell_gate = fluid.layers.sigmoid(
+    cell_gate = fluid.layers.tanh(
         x=gate_common(word, prev_hidden, lstm_size))

     cell = fluid.layers.sums(input=[
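The sigmoid-to-tanh change above for cell_gate matches the standard LSTM formulation: the input, forget, and output gates are squashed with sigmoid into (0, 1), while the candidate cell value uses tanh so it can take values in (-1, 1). A minimal NumPy sketch of that standard step, with hypothetical weights and shapes rather than the fluid API:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, U, b):
    # Standard LSTM step: i/f/o gates use sigmoid, the cell candidate uses tanh.
    z = W @ x + U @ h_prev + b                    # stacked pre-activations, shape (4 * hidden,)
    i, f, o, g = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # gates in (0, 1)
    g = np.tanh(g)                                # candidate cell value in (-1, 1)
    c = f * c_prev + i * g                        # new cell state
    h = o * np.tanh(c)                            # new hidden state
    return h, c
```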
@@ -164,10 +130,10 @@ def gate_common(
     exe.run(fluid.default_startup_program())

     def train_loop(pass_num, crop_size):
-        cache = cache_reader(
-            crop_sentence(imdb.train(word_dict), crop_size), clean=args.clean)
         for pass_id in range(pass_num):
-            train_reader = batch(cache, batch_size=args.batch_size)
+            train_reader = batch(
+                crop_sentence(imdb.train(word_dict), crop_size),
+                batch_size=args.batch_size)

             pass_start_time = time.time()
             for batch_id, data in enumerate(train_reader()):
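With the pickle cache removed, each pass simply re-composes the readers: crop_sentence wraps the raw IMDB reader and batch groups its output. A standalone sketch of that reader-composition pattern, where toy_reader and crop are hypothetical stand-ins and only the composition mirrors the code above:

```python
def toy_reader():
    # Stand-in for imdb.train(word_dict): yields (word_ids, label) samples.
    for i in range(10):
        yield list(range(i + 1)), i % 2

def crop(reader, crop_size):
    # Mirrors the idea of crop_sentence: truncate each sample's word list.
    def impl():
        for words, label in reader():
            yield words[:crop_size], label
    return impl

def batch(reader, batch_size):
    # Group consecutive samples into lists of size batch_size.
    def impl():
        buf = []
        for sample in reader():
            buf.append(sample)
            if len(buf) == batch_size:
                yield buf
                buf = []
        if buf:
            yield buf
    return impl

train_reader = batch(crop(toy_reader, 3), batch_size=4)
for data in train_reader():
    print(len(data), data[0])
```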
@@ -178,7 +144,8 @@ def train_loop(pass_num, crop_size):
                 feed={"words": tensor_words,
                       "label": label},
                 fetch_list=[loss])[0]
-            print 'Pass', pass_id, 'Batch', batch_id, 'loss', loss_np
+            print("pass_id=%d, batch_id=%d, loss: %f" %
+                  (pass_id, batch_id, loss_np))

             pass_end_time = time.time()
             time_consumed = pass_end_time - pass_start_time
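The other part of the style fix, visible in the first and last hunks, is moving from Python 2 print statements to the print function via `from __future__ import print_function`, so the same source runs under both Python 2 and 3. A minimal standalone illustration, with made-up values:

```python
from __future__ import print_function

pass_id, batch_id, loss_np = 0, 3, 0.251  # hypothetical values for illustration
# With print_function imported, the function form behaves identically on Python 2 and 3.
print("pass_id=%d, batch_id=%d, loss: %f" % (pass_id, batch_id, loss_np))
```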