diff --git a/rest.py b/rest.py index ca9731d..229ca6e 100644 --- a/rest.py +++ b/rest.py @@ -6,14 +6,20 @@ @Description : restful服务 """ +import argparse import base64 +import sys import cv2 import numpy as np +import torch import tornado.httpserver import tornado.wsgi from flask import Flask, request +import crnn +from config import cfg + app = Flask(__name__) @@ -66,4 +72,19 @@ def start_tornado(app, port=5000): if __name__ == '__main__': + parse = argparse.ArgumentParser() + parse.add_argument('-h', "--horizontal-weight-path", type=str, default=None, help="weight path") + parse.add_argument('-v', "--vertical-weight-path", type=str, default=None, help="weight path") + args = parse.parse_args(sys.argv[1:]) + alpha = cfg.word.get_all_words() + # 加载权重 + h_net = crnn.CRNN(num_classes=len(alpha)) + h_net.load_state_dict(torch.load(args.horizontal_weight_path, map_location='cpu')['model']) + h_net.eval() + # 垂直方向 + v_net = crnn.CRNN(num_classes=len(alpha)) + v_net.load_state_dict(torch.load(args.vertical_weight_path, map_location='cpu')['model']) + v_net.eval() + + # 启动restful服务 start_tornado(app, 5000)