Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model average option for OCR CTC model #740

Merged
merged 3 commits into from
Mar 27, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions fluid/ocr_recognition/crnn_ctc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def conv_bn_pool(input,
bias_attr=bias,
is_test=is_test)
tmp = fluid.layers.pool2d(
input=tmp, pool_size=2, pool_type='max', pool_stride=2, use_cudnn=True)
input=tmp,
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=True,
ceil_mode=True)

return tmp

Expand Down Expand Up @@ -148,14 +153,20 @@ def ctc_train_net(images, label, args, num_classes):

optimizer = fluid.optimizer.Momentum(
learning_rate=args.learning_rate, momentum=args.momentum)
optimizer.minimize(sum_cost)

_, params_grads = optimizer.minimize(sum_cost)
model_average = None
if args.model_average:
model_average = fluid.optimizer.ModelAverage(
params_grads,
args.average_window,
min_average_window=args.min_average_window,
max_average_window=args.max_average_window)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
casted_label = fluid.layers.cast(x=label, dtype='int64')
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
return sum_cost, error_evaluator
return sum_cost, error_evaluator, model_average


def ctc_infer(images, num_classes):
Expand Down
28 changes: 17 additions & 11 deletions fluid/ocr_recognition/ctc_train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Trainer for OCR CTC model."""
import paddle.v2 as paddle
import paddle.fluid as fluid
import dummy_reader
import ctc_reader
Expand All @@ -24,12 +23,16 @@
# Network size and device selection.
add_arg('rnn_hidden_size', int, 200, "Hidden size of rnn layers.")
# The two adjacent literals are concatenated by Python, so the first one must
# end with a space to keep the help text readable.
add_arg('device', int, 0, "Device id. '-1' means running on CPU "
        "while '0' means GPU-0.")
# Model-averaging options; the window values are read from `args` when the
# training network is built (presumably forwarded to
# fluid.optimizer.ModelAverage -- confirm against ctc_train_net).
add_arg('model_average', bool, True, "Whether to average model for evaluation.")
add_arg('min_average_window', int, 10000, "Min average window.")
add_arg('max_average_window', int, 15625, "Max average window.")
add_arg('average_window', float, 0.15, "Average window.")

# yapf: disable

def load_parameter(place):
    """Overwrite variables in the global scope with previously saved values.

    The values come from ``load_param`` (defined elsewhere), called with the
    hard-coded name map ``./name.map`` and the model directory
    ``./data/model/results_without_avg_window/pass-00000/``.

    Args:
        place: Device place handed to each tensor's ``set`` call, i.e. where
            the loaded values are stored.
    """
    # Presumably a mapping from variable name to value -- it is iterated by
    # name and indexed by name below; confirm against load_param.
    loaded = load_param(
        './name.map',
        './data/model/results_without_avg_window/pass-00000/')
    scope = fluid.global_scope()
    for name in loaded:
        tensor = scope.find_var(name).get_tensor()
        tensor.set(loaded[name], place)

Expand All @@ -41,7 +44,8 @@ def train(args, data_reader=dummy_reader):
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1)
sum_cost, error_evaluator = ctc_train_net(images, label, args, num_classes)
sum_cost, error_evaluator, model_average = ctc_train_net(images, label, args, num_classes)

# data reader
train_reader = data_reader.train(args.batch_size)
test_reader = data_reader.test()
Expand All @@ -51,7 +55,6 @@ def train(args, data_reader=dummy_reader):
place = fluid.CUDAPlace(args.device)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

#load_parameter(place)

inference_program = fluid.io.get_inference_program(error_evaluator)
Expand All @@ -78,13 +81,16 @@ def train(args, data_reader=dummy_reader):
sys.stdout.flush()
batch_id += 1

# evaluate model on test data
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
print "\nEnd pass[%d]; Test seq error: %s.\n" % (
pass_id, str(test_seq_error[0]))
with model_average.apply(exe):
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
if model_average != None:
model_average.restore(exe)

print "\nEnd pass[%d]; Test seq error: %s.\n" % (
pass_id, str(test_seq_error[0]))

def main():
args = parser.parse_args()
Expand Down