From ec2b4307d4f7aa1b790834040104b3b6b0b61d25 Mon Sep 17 00:00:00 2001 From: yizt Date: Sat, 11 Jan 2020 20:15:32 +0800 Subject: [PATCH] inference image directory --- demo.py | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/demo.py b/demo.py index 82dea30..a70dbbd 100644 --- a/demo.py +++ b/demo.py @@ -5,14 +5,17 @@ @Author : yizuotian @Description : """ -import sys import argparse import itertools -import crnn -from PIL import Image +import os +import sys + +import cv2 import numpy as np import torch -import cv2 +from PIL import Image + +import crnn from config import cfg @@ -38,26 +41,39 @@ def load_image(image_path): return image -def main(args): - alpha = cfg.word.get_all_words() - net = crnn.CRNN(num_classes=len(alpha)) - net.load_state_dict(torch.load(args.weight_path, map_location='cpu')['model']) - net.eval() - # load image - image = load_image(args.image_path) +def inference_image(net, alpha, image_path): + image = load_image(image_path) image = torch.FloatTensor(image) predict = net(image)[0].detach().numpy() # [W,num_classes] label = np.argmax(predict[:], axis=1) label = [alpha[class_id] for class_id in label] - print(''.join(label)) label = [k for k, g in itertools.groupby(list(label))] - print(''.join(label)) + label = ''.join(label).replace(' ', '') + return label + + +def main(args): + alpha = cfg.word.get_all_words() + net = crnn.CRNN(num_classes=len(alpha)) + net.load_state_dict(torch.load(args.weight_path, map_location='cpu')['model']) + net.eval() + # load image + if args.image_dir: + image_path_list = [os.path.join(args.image_dir, n) for n in os.listdir(args.image_dir)] + image_path_list.sort() + for image_path in image_path_list: + label = inference_image(net, alpha, image_path) + print("image_path:{},label:{}".format(image_path, label)) + else: + label = inference_image(net, alpha, args.image_path) + print("image_path:{},label:{}".format(args.image_path, label)) if __name__ == '__main__': parse = argparse.ArgumentParser() parse.add_argument("--image-path", type=str, default=None, help="test image path") parse.add_argument("--weight-path", type=str, default=None, help="weight path") + parse.add_argument("--image-dir", type=str, default=None, help="test image directory") arguments = parse.parse_args(sys.argv[1:]) main(arguments)