paddle v2: CTC error is not shown in the training log #3802
@xieshufu Hi, could you please post the complete configuration? |
import gzip  # needed for gzip.open() when loading a saved model
import sys

import paddle.v2 as paddle
from ocr_8conv import ocr_8conv_net
from ocr_4conv import ocr_4conv_net
#from ocr_reader import train_reader, test_reader
from ocr_data import DataGenerator


def main():
    datadim = 48 * 48 * 1
    classdim = 21501

    # PaddlePaddle init
    paddle.init(use_gpu=True, trainer_count=1)

    image = paddle.layer.data(
        name="image", type=paddle.data_type.dense_vector(datadim))

    # Add neural network config
    # option 1. resnet
    # net = resnet_cifar10(image, depth=32)
    # option 2. vgg
    output = ocr_4conv_net(image, classdim+1)

    lbl = paddle.layer.data(
        name="label", type=paddle.data_type.integer_value_sequence(classdim))
    # The extra class (index classdim) is the CTC blank label
    cost = paddle.layer.warp_ctc(input=output,
                                 label=lbl,
                                 size=classdim+1,
                                 blank=classdim,
                                 norm_by_times=True)
    ctc_eval = paddle.evaluator.ctc_error(input=output, label=lbl)

    # Create parameters
    model_path = ""
    if model_path == "":
        parameters = paddle.parameters.create(cost)
    else:
        parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))

    train_list_path = "./train_3W.list"
    test_list_path = "./test_image.list"
    train_generator = DataGenerator(file_list_path=train_list_path)
    test_generator = DataGenerator(file_list_path=test_list_path)
    train_batch_reader = train_generator.batch_train_reader_creator(batch_size=32)
    test_batch_reader = test_generator.batch_test_reader_creator(batch_size=1)

    # Create optimizer
    momentum_optimizer = paddle.optimizer.Momentum(
        momentum=0.9,
        learning_rate=0.001)

    # End batch and end pass event handler
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print "\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics)
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
            # Run the test set every 100 batches
            if event.batch_id % 100 == 0:
                result = trainer.test(
                    reader=test_batch_reader,
                    feeding={'image': 0,
                             'label': 1})
                print "\nTest with Pass %d_%d, %s" % (event.pass_id, event.batch_id, result.metrics)
        if isinstance(event, paddle.event.EndPass):
            # save parameters
            with open('./result/params_pass_%d.tar' % event.pass_id, 'w') as f:
                parameters.to_tar(f)
            result = trainer.test(
                reader=test_batch_reader,
                feeding={'image': 0,
                         'label': 1})
            print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)

    # Create trainer; the CTC error evaluator is passed as an extra layer
    trainer = paddle.trainer.SGD(
        cost=cost,
        parameters=parameters,
        update_equation=momentum_optimizer,
        extra_layers=ctc_eval)

    trainer.train(
        reader=train_batch_reader,
        num_passes=200,
        event_handler=event_handler,
        feeding={'image': 0, 'label': 1})


if __name__ == '__main__':
    main()
|
The CTC error evaluator indeed cannot produce any output under v2. This problem should be fixed. |
The CTC evaluator runs into the same problem as the CRF evaluator, so it can be fixed on the same principle as this PR.
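The training-set information printed by the old version of PADDLE and by paddle v2 does not match up. The old version prints two parts: one is the metric over all samples processed so far, and the other is the metric over the samples processed in the current period. paddle v2 prints only one part. Which part of the old PADDLE log does it correspond to? The paddle v2 log output is as follows: |
The training error rate in paddle v2 jumps around quite a lot, as the table below shows:
The training error rate in the old version of PADDLE does not fluctuate nearly as much: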
|
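What paddle v2 prints is the metric over the samples processed in the current period. The metric over all samples processed so far has to be computed by the user. |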
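A cumulative figure can be approximated outside the trainer by weighting each period's value by the number of samples it covers. The following is only a minimal sketch: `RunningMetric` is a hypothetical helper, not part of paddle v2, and it assumes the per-period value is an average over that period's samples. If the metric is normalized by something else (for example total label length), that quantity should be used as the weight instead.

```python
class RunningMetric(object):
    """Hypothetical helper: folds per-period averages into one overall average."""

    def __init__(self):
        self.weighted_sum = 0.0
        self.num_samples = 0

    def update(self, period_value, period_samples):
        # Weight the period's average by how many samples it covers.
        self.weighted_sum += period_value * period_samples
        self.num_samples += period_samples

    def average(self):
        # Average over everything seen so far (0.0 before any update).
        if self.num_samples == 0:
            return 0.0
        return self.weighted_sum / self.num_samples


running = RunningMetric()
running.update(0.50, 3200)  # e.g. 100 batches of 32 samples, period error 0.50
running.update(0.30, 3200)  # next period, error 0.30
overall = running.average()
```

In an event handler, `update` would be called with the value read from `event.metrics` each time it is printed, and `average()` logged alongside it.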
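Because this is the evaluation of a sequence recognition model, it involves the sequence error rate between strings, as well as insertions, deletions, substitutions and so on. That is not easy to compute outside the framework. |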
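For reference, the substitution, deletion and insertion counts between a decoded sequence and its label can be obtained with a standard edit-distance dynamic program. The sketch below is not the implementation behind `ctc_error`. The names `edit_ops` and `sequence_error_rate` are made up for illustration, and it assumes the network output has already been decoded into a list of label ids with blanks and repeats removed.

```python
def edit_ops(hyp, ref):
    """Return (substitutions, deletions, insertions) of one minimal edit
    that turns the reference `ref` into the hypothesis `hyp`."""
    rows, cols = len(ref) + 1, len(hyp) + 1
    # cost[i][j] = (distance, subs, dels, ins) for ref[:i] versus hyp[:j]
    cost = [[None] * cols for _ in range(rows)]
    cost[0][0] = (0, 0, 0, 0)
    for i in range(1, rows):
        cost[i][0] = (i, 0, i, 0)  # every reference token deleted
    for j in range(1, cols):
        cost[0][j] = (j, 0, 0, j)  # every hypothesis token inserted
    for i in range(1, rows):
        for j in range(1, cols):
            if ref[i - 1] == hyp[j - 1]:
                cost[i][j] = cost[i - 1][j - 1]  # match, no extra cost
                continue
            d, s, dl, ins = cost[i - 1][j - 1]
            substitute = (d + 1, s + 1, dl, ins)
            d, s, dl, ins = cost[i - 1][j]
            delete = (d + 1, s, dl + 1, ins)
            d, s, dl, ins = cost[i][j - 1]
            insert = (d + 1, s, dl, ins + 1)
            # Tuples compare by distance first, so this keeps a minimal edit.
            cost[i][j] = min(substitute, delete, insert)
    return cost[-1][-1][1:]


def sequence_error_rate(hyp, ref):
    """Edit distance divided by the reference length."""
    if len(ref) == 0:
        return 0.0 if len(hyp) == 0 else 1.0
    return float(sum(edit_ops(hyp, ref))) / len(ref)


subs, dels, ins = edit_ops(hyp=[1, 9, 3], ref=[1, 2, 3])
```

Summing these counts over a test set and dividing by the total reference length gives an overall error rate, but whether that matches the normalization the old PADDLE log uses is not confirmed here.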
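I configured a network and ran it under PADDLE v2. After adding the ctc_error metric, no error information appears in the log:
The training log shows the following:
Could the PADDLE team please help with this!
I would like the log to show information like the following, which makes it easier to judge the model's accuracy: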