Skip to content

Commit

Permalink
inference image directory
Browse files Browse the repository at this point in the history
  • Loading branch information
yizt committed Jan 11, 2020
1 parent a5315ee commit ec2b430
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
@Author : yizuotian
@Description :
"""
import sys
import argparse
import itertools
import crnn
from PIL import Image
import os
import sys

import cv2
import numpy as np
import torch
import cv2
from PIL import Image

import crnn
from config import cfg


Expand All @@ -38,26 +41,39 @@ def load_image(image_path):
return image


def main(args):
alpha = cfg.word.get_all_words()
net = crnn.CRNN(num_classes=len(alpha))
net.load_state_dict(torch.load(args.weight_path, map_location='cpu')['model'])
net.eval()
# load image
image = load_image(args.image_path)
def inference_image(net, alpha, image_path):
image = load_image(image_path)
image = torch.FloatTensor(image)

predict = net(image)[0].detach().numpy() # [W,num_classes]
label = np.argmax(predict[:], axis=1)
label = [alpha[class_id] for class_id in label]
print(''.join(label))
label = [k for k, g in itertools.groupby(list(label))]
print(''.join(label))
label = ''.join(label).replace(' ', '')
return label


def main(args):
alpha = cfg.word.get_all_words()
net = crnn.CRNN(num_classes=len(alpha))
net.load_state_dict(torch.load(args.weight_path, map_location='cpu')['model'])
net.eval()
# load image
if args.image_dir:
image_path_list = [os.path.join(args.image_dir, n) for n in os.listdir(args.image_dir)]
image_path_list.sort()
for image_path in image_path_list:
label = inference_image(net, alpha, image_path)
print("image_path:{},label:{}".format(image_path, label))
else:
label = inference_image(net, alpha, args.image_path)
print("image_path:{},label:{}".format(args.image_path, label))


if __name__ == '__main__':
parse = argparse.ArgumentParser()
parse.add_argument("--image-path", type=str, default=None, help="test image path")
parse.add_argument("--weight-path", type=str, default=None, help="weight path")
parse.add_argument("--image-dir", type=str, default=None, help="test image directory")
arguments = parse.parse_args(sys.argv[1:])
main(arguments)

0 comments on commit ec2b430

Please sign in to comment.