-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
73 lines (58 loc) · 2.4 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from fastapi import FastAPI, File, UploadFile, Request, Form
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from typing import Optional
from PIL import Image
import requests
import uvicorn
import io
import os
import time
import argparse
from pydantic import BaseModel
from ultocr.inference import OCR
def download_image(image_url, timeout=10):
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        image_url: URL of the image to download.
        timeout: Seconds to wait for the remote server before giving up.
            Without a timeout ``requests.get`` can block indefinitely on an
            unresponsive host.

    Returns:
        The downloaded image, converted to RGB mode.

    Raises:
        requests.RequestException: On connection errors, timeouts, or an
            HTTP error status.
        Exception: Whatever ``PIL.Image.open`` raises when the response body
            cannot be decoded as an image.
    """
    # NOTE(review): ``predict`` passes the client-supplied URL straight in, so
    # the server will fetch arbitrary addresses; consider restricting schemes
    # and hosts if this service is reachable by untrusted clients.
    response = requests.get(image_url, timeout=timeout)
    # Fail early on 4xx/5xx instead of trying to decode an error page.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert('RGB')
    return image
class Result(BaseModel):
    """Payload that the ``/predict`` handler serializes back to the client."""

    # URL of the image the request referred to.
    url: str
    # Outcome code carried inside the JSON body; declared as a string.
    status_code: str
    # Optional human-readable message about the outcome.
    description: Optional[str] = None
    # Optional recognized text; the handler joins the OCR output with newlines.
    text: Optional[str] = None
    # Optional time spent in the OCR call; declared as a string.
    latency: Optional[str] = None
def return_response(response):
    """Convert *response* with ``jsonable_encoder`` and wrap it in a ``JSONResponse``."""
    return JSONResponse(jsonable_encoder(response))
def parse_args(argv=None):
    """Parse the command-line options that select the OCR models.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        An ``argparse.Namespace`` with the attributes ``det_model``,
        ``reg_model``, ``det_config``, ``reg_config``, ``det_weight`` and
        ``reg_weight`` (all strings).
    """
    parser = argparse.ArgumentParser(description='Hyper parameter')
    # (flag, default, help) for every option; all of them are plain strings.
    options = (
        ('--det_model', 'DB', 'text detection model'),
        ('--reg_model', 'MASTER', 'text recognition model'),
        ('--det_config', 'config/db_resnet50.yaml', 'DBnet config'),
        ('--reg_config', 'config/master.yaml', 'MASTER config'),
        ('--det_weight', 'saved/db_pretrain.pth', 'DBnet weight'),
        ('--reg_weight', 'saved/master_pretrain.pth', 'MASTER weight'),
    )
    for flag, default, help_text in options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    return parser.parse_args(argv)
# Application instance; the route below registers itself on it.
app = FastAPI()


@app.get('/predict')
def predict(image_url: str):
    """Download the image at *image_url*, run OCR on it and return the text.

    Args:
        image_url: Query parameter holding the URL of the image to process.

    Returns:
        A ``JSONResponse`` built from a ``Result``. The outcome is reported in
        the body's ``status_code`` field ('200' or '400'); the response is
        created by ``return_response`` without an explicit HTTP status.
    """
    try:
        img = download_image(image_url)
    except Exception:
        # Any download/decoding failure is reported to the client in the body.
        # ``Result.status_code`` is declared as ``str``, so pass a string
        # explicitly instead of relying on int -> str coercion, which newer
        # pydantic releases reject.
        failure = Result(url=image_url, status_code='400', description='Cant download image..')
        return return_response(failure)

    # ``model`` is the module-level OCR instance created in the __main__ block.
    # perf_counter is monotonic, so the elapsed value cannot go negative if
    # the wall clock is adjusted mid-request.
    started = time.perf_counter()
    lines = model.get_result(img)
    text = '\n'.join(lines)
    elapsed = time.perf_counter() - started

    # ``Result.latency`` is declared as ``Optional[str]``; convert explicitly.
    result = Result(
        url=image_url,
        status_code='200',
        latency=str(elapsed),
        description='Get text sucessfull..',
        text=text,
    )
    return return_response(result)
if __name__ == '__main__':
    # Read the CLI options first so the model can be built from them.
    opt = parse_args()

    # Build the OCR pipeline once at start-up and report how long it took.
    # ``model`` is a module global that the /predict handler reads.
    load_started = time.time()
    model = OCR(
        opt.det_model,
        opt.reg_model,
        opt.det_config,
        opt.reg_config,
        opt.det_weight,
        opt.reg_weight,
    )
    print(f'load model time {time.time() - load_started}')

    # Serve the app on the loopback interface only.
    uvicorn.run(app, host="127.0.0.1", port=8000, reload=False)