Skip to content

Commit

Permalink
加载模型
Browse files Browse the repository at this point in the history
  • Loading branch information
yizt committed Apr 6, 2020
1 parent da2ad59 commit 59ac6a9
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
@Description : restful服务
"""

import argparse
import base64
import sys

import cv2
import numpy as np
import torch
import tornado.httpserver
import tornado.wsgi
from flask import Flask, request

import crnn
from config import cfg

app = Flask(__name__)


Expand Down Expand Up @@ -66,4 +72,19 @@ def start_tornado(app, port=5000):


if __name__ == '__main__':
parse = argparse.ArgumentParser()
parse.add_argument('-h', "--horizontal-weight-path", type=str, default=None, help="weight path")
parse.add_argument('-v', "--vertical-weight-path", type=str, default=None, help="weight path")
args = parse.parse_args(sys.argv[1:])
alpha = cfg.word.get_all_words()
# 加载权重
h_net = crnn.CRNN(num_classes=len(alpha))
h_net.load_state_dict(torch.load(args.horizontal_weight_path, map_location='cpu')['model'])
h_net.eval()
# 垂直方向
v_net = crnn.CRNN(num_classes=len(alpha))
v_net.load_state_dict(torch.load(args.vertical_weight_path, map_location='cpu')['model'])
v_net.eval()

# 启动restful服务
start_tornado(app, 5000)

0 comments on commit 59ac6a9

Please sign in to comment.