Skip to content

Commit

Permalink
support model agent for idp sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
hrfng committed Aug 30, 2024
1 parent f21d453 commit 494ffa1
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 21 deletions.
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ xlrd==2.0.1
uvicorn
fastapi
orjson

# client
tritonclient[http]==2.41.0
26 changes: 25 additions & 1 deletion src/bisheng_unstructured/models/layout_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import base64
import copy
import json

import numpy as np
import requests
import tritonclient.http as httpclient


# Layout Agent Version 0.1, update at 2023.08.18
class LayoutAgent(object):
class LayoutAgentV0(object):
def __init__(self, *args, **kwargs):
self.ep = kwargs.get("layout_ep")
self.client = requests.Session()
Expand All @@ -25,3 +28,24 @@ def predict(self, inp):
raise Exception(f"timeout in layout predict")
except Exception as e:
raise Exception(f"exception in layout predict: [{e}]")


class LayoutAgent:
def __init__(self, *args, **kwargs):
ep_parts = kwargs.get("layout_ep").split("/")
self.model = ep_parts[-2]
server_url = ep_parts[2]
self.client = httpclient.InferenceServerClient(url=server_url, verbose=False)

def predict(self, inp):
# b64_image = base64.b64encode(open(image_file, 'rb').read()).decode('utf-8')
input0_data = np.asarray([json.dumps(inp)], dtype=np.object_)
inputs = [httpclient.InferInput("INPUT", [1], "BYTES")]
inputs[0].set_data_from_numpy(input0_data)
outputs = [httpclient.InferRequestedOutput("OUTPUT")]
try:
response = self.client.infer(self.model, inputs, request_id=str(1), outputs=outputs)
output_data = json.loads(response.as_numpy("OUTPUT")[0].decode("utf-8"))
except Exception as e:
raise Exception(f"exception in layout predict: [{e}]")
return output_data
34 changes: 16 additions & 18 deletions src/bisheng_unstructured/models/ocr_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from PIL import Image

from bisheng_unstructured.config.settings import settings

from bisheng_unstructured.models.common import (
bbox_overlap,
draw_polygon,
Expand All @@ -30,21 +29,21 @@
},
"scene_mapping": {
"print": {
"det": "general_text_det_mrcnn_v2.0",
"recog": "transformer-blank-v0.2-faster",
"det": "general_text_det_v2.0",
"recog": "general_text_reg_nb_v1.0_faster",
},
"hand": {
"det": "general_text_det_mrcnn_v2.0",
"recog": "transformer-hand-v1.16-faster",
"det": "general_text_det_v2.0",
"recog": "general_text_reg_nb_v1.0_faster",
},
"print_recog": {
"recog": "transformer-blank-v0.2-faster",
"recog": "general_text_reg_nb_v1.0_faster",
},
"hand_recog": {
"recog": "transformer-hand-v1.16-faster",
"recog": "general_text_reg_nb_v1.0_faster",
},
"det": {
"det": "general_text_det_mrcnn_v2.0",
"det": "general_text_det_v2.0",
},
},
}
Expand All @@ -53,7 +52,6 @@
# OCR Agent Version 0.1, update at 2023.08.18
# - add predict_with_mask support recog with embedding formula, 2024.01.16
class OCRAgent(object):

def __init__(self, **kwargs):
self.ep = kwargs.get("ocr_model_ep")
self.client = requests.Session()
Expand Down Expand Up @@ -118,10 +116,8 @@ def predict_with_mask(self, img0, mf_out, scene="print", **kwargs):

xmin, ymin = max(0, int(box[0][0]) - 1), max(0, int(box[0][1]) - 1)
xmax, ymax = (
min(img0.size[0],
int(box[2][0]) + 1),
min(img0.size[1],
int(box[2][1]) + 1),
min(img0.size[0], int(box[2][0]) + 1),
min(img0.size[1], int(box[2][1]) + 1),
)
img[ymin:ymax, xmin:xmax, :] = 255

Expand Down Expand Up @@ -151,11 +147,13 @@ def predict_with_mask(self, img0, mf_out, scene="print", **kwargs):
emb_bbox = [bb[0], bb[1], bb[4], bb[5]]
bbox_iou = bbox_overlap(hori_bbox, emb_bbox)
if bbox_iou > EMB_BBOX_THREHOLD:
embed_mfs.append({
"position": emb_bbox,
"text": box_info["text"],
"type": box_info["type"],
})
embed_mfs.append(
{
"position": emb_bbox,
"text": box_info["text"],
"type": box_info["type"],
}
)

ocr_boxes = split_line_image(hori_bbox, embed_mfs)
text_bboxes.extend(ocr_boxes)
Expand Down
8 changes: 8 additions & 0 deletions src/bisheng_unstructured/models/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

## models endpoint for idp sdk

- ocr model endpoint: http://{host:port}/v2/idp/idp_app/infer
- layout model endpoint: http://{host:port}/v2/models/elem_layout_v1/infer
- table det model endpoint: http://{host:port}/v2/models/elem_table_detect_v1/infer
- table rowcol model endpoint: http://{host:port}/v2/models/elem_table_rowcol_detect_v1/infer
- table cell model endpoint: http://{host:port}/v2/models/elem_table_cell_detect_v1/infer
80 changes: 78 additions & 2 deletions src/bisheng_unstructured/models/table_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import base64
import copy
import json

import numpy as np
import requests
import tritonclient.http as httpclient


# Table Agent Version 0.1, update at 2023.08.18
class TableAgent(object):
class TableAgentV0(object):
def __init__(self, **kwargs):
cell_model_ep = kwargs.get("cell_model_ep")
rowcol_model_ep = kwargs.get("rowcol_model_ep")
Expand Down Expand Up @@ -39,7 +42,7 @@ def predict(self, inp):


# TableDet Agent Version 0.1, update at 2023.08.31
class TableDetAgent(object):
class TableDetAgentV0(object):
def __init__(self, **kwargs):
self.ep = kwargs.get("table_model_ep")
self.client = requests.Session()
Expand All @@ -57,3 +60,76 @@ def predict(self, inp):
raise Exception(f"timeout in table det predict")
except Exception as e:
raise Exception(f"exception in table det predict: [{e}]")


class TableAgent(object):
def __init__(self, **kwargs):
ep_parts = kwargs.get("cell_model_ep").split("/")
server_url = ep_parts[2]
self.cell_model = ep_parts[-2]
self.cell_client = httpclient.InferenceServerClient(url=server_url, verbose=False)

ep_parts = kwargs.get("rowcol_model_ep").split("/")
server_url = ep_parts[2]
self.rowcol_model = ep_parts[-2]
self.rowcol_client = httpclient.InferenceServerClient(url=server_url, verbose=False)

self.timeout = kwargs.get("timeout", 60)
self.params = {
"sep_char": " ",
"longer_edge_size": None,
"padding": False,
}

def predict(self, inp):
scene = inp.pop("scene", "rowcol")
if scene == "rowcol":
client, model = self.rowcol_client, self.rowcol_model
else:
client, model = self.cell_client, self.cell_model

payload = copy.deepcopy(self.params)
payload.update(inp)

# ocr_result = json.dumps(ocr_result)
# table_bbox = table_result["bboxes"]
# b64_image = base64.b64encode(open(image_file, 'rb').read()).decode('utf-8')
# payload = {'b64_image': b64_image, 'table_bboxes': table_bbox, 'ocr_result': ocr_result}

input0_data = np.asarray([json.dumps(payload)], dtype=np.object_)
# print(input0_data)
inputs = [httpclient.InferInput("INPUT", [1], "BYTES")]
inputs[0].set_data_from_numpy(input0_data)
outputs = [httpclient.InferRequestedOutput("OUTPUT")]
try:
response = client.infer(model, inputs, request_id=str(1), outputs=outputs)
print("response", response)
output_data = json.loads(response.as_numpy("OUTPUT")[0].decode("utf-8"))
except Exception as e:
raise Exception(f"exception in table structure predict: [{e}]")

return output_data


class TableDetAgent(object):
def __init__(self, **kwargs):
ep_parts = kwargs.get("table_model_ep").split("/")
server_url = ep_parts[2]
self.model = ep_parts[-2]
self.client = httpclient.InferenceServerClient(url=server_url, verbose=False)
self.timeout = kwargs.get("timeout", 60)

def predict(self, inp):
# b64data = base64.b64encode(open(image_file, 'rb').read()).decode('utf-8')
input0_data = np.asarray([json.dumps(inp)], dtype=np.object_)

inputs = [httpclient.InferInput("INPUT", [1], "BYTES")]
inputs[0].set_data_from_numpy(input0_data)
outputs = [httpclient.InferRequestedOutput("OUTPUT")]
try:
response = self.client.infer(self.model, inputs, request_id=str(1), outputs=outputs)
output_data = json.loads(response.as_numpy("OUTPUT")[0].decode("utf-8"))
except Exception as e:
raise Exception(f"exception in table det predict: [{e}]")

return output_data
56 changes: 56 additions & 0 deletions tests/test_idp_models_sdk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# flake8: noqa
import base64
import hashlib
import json
import os

import pytest

from bisheng_unstructured.models.layout_agent import LayoutAgent
from bisheng_unstructured.models.ocr_agent import OCRAgent
from bisheng_unstructured.models.table_agent import TableAgent, TableDetAgent

configs = dict(
layout_ep="http://192.168.106.20:10502/v2/models/elem_layout_v1/infer",
cell_model_ep="http://192.168.106.20:10502/v2/models/elem_table_cell_detect_v1/infer",
rowcol_model_ep="http://192.168.106.20:10502/v2/models/elem_table_rowcol_detect_v1/infer",
table_model_ep="http://192.168.106.20:10502/v2/models/elem_table_detect_v1/infer",
ocr_model_ep="http://192.168.106.20:10502/v2/idp/idp_app/infer",
)


# @pytest.mark.skip
def test_layout():
layout_agent = LayoutAgent(**configs)

image_file = "data/001.png"
b64_image = base64.b64encode(open(image_file, "rb").read()).decode("utf-8")
inp = {"b64_image": b64_image}
result = layout_agent.predict(inp)
print("result", result)


# @pytest.mark.skip
def test_ocr():
ocr_agent = OCRAgent(**configs)

image_file = "data/001.png"
b64_image = base64.b64encode(open(image_file, "rb").read()).decode("utf-8")
inp = {"b64_image": b64_image}
result = ocr_agent.predict(inp)
print("result", result)


def test_table_det():
table_det_agent = TableDetAgent(**configs)
table_agent = TableAgent(**configs)
ocr_agent = OCRAgent(**configs)

image_file = "data/001.png"
b64_image = base64.b64encode(open(image_file, "rb").read()).decode("utf-8")
inp = {"b64_image": b64_image}
table_bboxes = table_det_agent.predict(inp)["bboxes"]
ocr_result = json.dumps(ocr_agent.predict(inp)["result"]["ocr_result"])
inp = {"b64_image": b64_image, "table_bboxes": table_bboxes, "ocr_result": ocr_result}
table_result = table_agent.predict(inp)
print("table_result", table_result)

0 comments on commit 494ffa1

Please sign in to comment.