Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Support GOT-OCR2_0 #2458

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,13 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y opencc
${{ env.SELF_HOST_PYTHON }} -m pip uninstall -y "faster_whisper"
${{ env.SELF_HOST_PYTHON }} -m pip install -U accelerate
${{ env.SELF_HOST_PYTHON }} -m pip install -U verovio
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_stable_diffusion.py && \
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_got_ocr2.py && \
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_whisper.py && \
Expand All @@ -203,6 +207,6 @@ jobs:
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/audio/tests xinference
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/image/tests/test_got_ocr2.py --ignore xinference/model/audio/tests xinference
fi
working-directory: .
14 changes: 11 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ all =
llama-cpp-python>=0.2.25,!=0.2.58
transformers>=4.43.2
torch>=2.0.0 # >=2.0 For CosyVoice
accelerate>=0.27.2
accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
protobuf
einops
tiktoken
tiktoken>=0.6.0
sentence-transformers>=3.1.0
vllm>=0.2.6 ; sys_platform=='linux'
diffusers>=0.30.0
Expand Down Expand Up @@ -131,6 +131,8 @@ all =
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
verovio>=4.3.1 # For got_ocr2
accelerate>=0.28.0 # For got_ocr2
intel =
torch==2.1.0a0
intel_extension_for_pytorch==2.1.10+xpu
Expand All @@ -139,7 +141,7 @@ llama_cpp =
transformers =
transformers>=4.43.2
torch
accelerate>=0.27.2
accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
Expand Down Expand Up @@ -174,6 +176,12 @@ image =
diffusers>=0.30.0 # fix conflict with matcha-tts
controlnet_aux
deepcache
verovio>=4.3.1 # For got_ocr2
transformers>=4.37.2 # For got_ocr2
tiktoken>=0.6.0 # For got_ocr2
accelerate>=0.28.0 # For got_ocr2
torch # For got_ocr2
torchvision # For got_ocr2
video =
diffusers>=0.30.0
imageio-ffmpeg
Expand Down
48 changes: 48 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
self._router.add_api_route(
"/v1/images/ocr",
self.create_ocr,
methods=["POST"],
dependencies=(
[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None
),
)
# SD WebUI API
self._router.add_api_route(
"/sdapi/v1/options",
Expand Down Expand Up @@ -1754,6 +1764,44 @@ async def create_inpainting(
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_ocr(
    self,
    model: str = Form(...),
    image: UploadFile = File(media_type="application/octet-stream"),
    kwargs: Optional[str] = Form(None),
) -> Response:
    """Run OCR on an uploaded image with the named model.

    Multipart form endpoint (`/v1/images/ocr`):
      * model  -- uid of a deployed model exposing an ``ocr`` method.
      * image  -- the image file to recognize.
      * kwargs -- optional JSON-encoded dict of extra keyword arguments
        forwarded verbatim to the model's ``ocr`` call.

    Returns the model output as a ``text/plain`` response body.
    Raises HTTP 400 on a bad model uid or a runtime failure in the model,
    HTTP 500 on any other error (mirrors the sibling image endpoints).
    """
    model_uid = model
    # Resolve the model actor first; lookup failures are reported
    # separately from inference failures.
    try:
        supervisor_ref = await self._get_supervisor_ref()
        model_ref = await supervisor_ref.get_model(model_uid)
    except ValueError as ve:
        logger.error(str(ve), exc_info=True)
        await self._report_error_event(model_uid, str(ve))
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

    try:
        # An absent "kwargs" field means no extra arguments.
        parsed_kwargs = {} if kwargs is None else json.loads(kwargs)
        im = Image.open(image.file)
        text = await model_ref.ocr(
            image=im,
            **parsed_kwargs,
        )
        # NOTE(review): the payload comes back through the model actor's
        # JSON wrapper; it is returned as-is with a text/plain media type.
        return Response(content=text, media_type="text/plain")
    except RuntimeError as re:
        logger.error(re, exc_info=True)
        await self._report_error_event(model_uid, str(re))
        raise HTTPException(status_code=400, detail=str(re))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

async def create_flexible_infer(self, request: Request) -> Response:
payload = await request.json()

Expand Down
19 changes: 19 additions & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,25 @@ def inpainting(
response_data = response.json()
return response_data

def ocr(self, image: Union[str, bytes], **kwargs):
    """Recognize text in an image via the server's OCR endpoint.

    Parameters
    ----------
    image:
        The image payload sent as the multipart ``image`` part.
        NOTE(review): a ``str`` value is uploaded literally as the part
        body, not opened as a file path — callers should pass raw bytes
        or a file-like object; confirm against other image handlers.
    **kwargs:
        Extra keyword arguments, JSON-encoded into the ``kwargs`` form
        field and forwarded to the server-side model.

    Returns the decoded JSON response body.
    Raises ``RuntimeError`` on any non-200 HTTP status.
    """
    endpoint = f"{self._base_url}/v1/images/ocr"
    # Build the multipart form: scalar fields use (None, value) tuples,
    # the image part carries an explicit filename and content type.
    form_parts: List[Any] = [
        ("model", (None, self._model_uid)),
        ("kwargs", (None, json.dumps(kwargs))),
        ("image", ("image", image, "application/octet-stream")),
    ]
    response = requests.post(endpoint, files=form_parts, headers=self.auth_headers)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to ocr the images, detail: {_get_error_string(response)}"
        )

    return response.json()


class RESTfulVideoModelHandle(RESTfulModelHandle):
def text_to_video(
Expand Down
19 changes: 19 additions & 0 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,25 @@ async def inpainting(
f"Model {self._model.model_spec} is not for creating image."
)

@request_limit
@log_async(
    logger=logger,
    ignore_kwargs=["image"],
)
async def ocr(
    self,
    image: "PIL.Image",
    *args,
    **kwargs,
):
    """Dispatch an OCR request to the wrapped model.

    Forwards *image* plus any extra positional/keyword arguments to the
    underlying model's ``ocr`` method through the JSON call wrapper.

    Raises ``AttributeError`` when the wrapped model does not implement
    ``ocr`` (i.e. it is not an OCR-capable model).
    """
    # Consistency fix: the sibling public entry points (e.g. ``infer``
    # below) are guarded by @request_limit; OCR requests should be
    # subject to the same per-model concurrency limit.
    if hasattr(self._model, "ocr"):
        return await self._call_wrapper_json(
            self._model.ocr,
            image,
            *args,
            **kwargs,
        )
    raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")

@request_limit
@log_async(logger=logger, ignore_kwargs=["image"])
async def infer(
Expand Down
7 changes: 4 additions & 3 deletions xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ peft
opencv-contrib-python-headless

# all
transformers>=4.34.1
accelerate>=0.27.2
transformers>=4.43.2
accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
protobuf
einops
tiktoken
tiktoken>=0.6.0
sentence-transformers>=3.1.0
diffusers>=0.30.0
controlnet_aux
Expand Down Expand Up @@ -75,6 +75,7 @@ qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
deepcache # for sd
verovio>=4.3.1 # For got_ocr2

# sglang
outlines>=0.0.44
Expand Down
5 changes: 3 additions & 2 deletions xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ passlib[bcrypt]
aioprometheus[starlette]>=23.12.0
nvidia-ml-py
async-timeout
transformers>=4.34.1
accelerate>=0.20.3
transformers>=4.43.2
accelerate>=0.28.0
sentencepiece
transformers_stream_generator
bitsandbytes
Expand Down Expand Up @@ -69,3 +69,4 @@ ormsgpack # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
verovio>=4.3.1 # For got_ocr2
37 changes: 35 additions & 2 deletions xinference/model/image/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import logging
import os
from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple
from typing import Dict, List, Literal, Optional, Tuple, Union

from ...constants import XINFERENCE_CACHE_DIR
from ...types import PeftModelConfig
from ..core import CacheableModelSpec, ModelDescription
from ..utils import valid_model_revision
from .ocr.got_ocr2 import GotOCR2Model
from .stable_diffusion.core import DiffusionModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -180,6 +181,28 @@ def get_cache_status(
return valid_model_revision(meta_path, model_spec.model_revision)


def create_ocr_model_instance(
    subpool_addr: str,
    devices: List[str],
    model_uid: str,
    model_spec: ImageModelFamilyV1,
    model_path: Optional[str] = None,
    model_name: Optional[str] = None,
    **kwargs,
) -> Tuple[GotOCR2Model, ImageModelDescription]:
    """Instantiate an OCR (GOT-OCR2) model and its description.

    Parameters
    ----------
    subpool_addr, devices:
        Placement info recorded in the returned ``ImageModelDescription``.
    model_uid:
        Unique id assigned to this model instance.
    model_spec:
        The matched image-model family spec; used for caching and metadata.
    model_path:
        Local weights path; downloaded via ``cache(model_spec)`` when absent.
    model_name:
        Accepted explicitly because the caller (``create_image_model_instance``)
        passes it as a keyword. Without this parameter it would be swallowed
        by ``**kwargs`` and forwarded to ``GotOCR2Model``, which may not
        accept it. It is intentionally not forwarded.
    **kwargs:
        Remaining options forwarded verbatim to ``GotOCR2Model``.
    """
    if not model_path:
        model_path = cache(model_spec)
    model = GotOCR2Model(
        model_uid,
        model_path,
        model_spec=model_spec,
        **kwargs,
    )
    model_description = ImageModelDescription(
        subpool_addr, devices, model_spec, model_path=model_path
    )
    return model, model_description


def create_image_model_instance(
subpool_addr: str,
devices: List[str],
Expand All @@ -189,8 +212,18 @@ def create_image_model_instance(
download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
model_path: Optional[str] = None,
**kwargs,
) -> Tuple[DiffusionModel, ImageModelDescription]:
) -> Tuple[Union[DiffusionModel, GotOCR2Model], ImageModelDescription]:
model_spec = match_diffusion(model_name, download_hub)
if model_spec.model_ability and "ocr" in model_spec.model_ability:
return create_ocr_model_instance(
subpool_addr=subpool_addr,
devices=devices,
model_uid=model_uid,
model_name=model_name,
model_spec=model_spec,
model_path=model_path,
**kwargs,
)
controlnet = kwargs.get("controlnet")
# Handle controlnet
if controlnet is not None:
Expand Down
9 changes: 9 additions & 0 deletions xinference/model/image/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -178,5 +178,14 @@
"model_ability": [
"inpainting"
]
},
{
"model_name": "GOT-OCR2_0",
"model_family": "ocr",
"model_id": "stepfun-ai/GOT-OCR2_0",
"model_revision": "cf6b7386bc89a54f09785612ba74cb12de6fa17c",
"model_ability": [
"ocr"
]
}
]
10 changes: 10 additions & 0 deletions xinference/model/image/model_spec_modelscope.json
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,15 @@
"model_revision": "62134b9d8e703b5d6f74f1534457287a8bba77ef"
}
]
},
{
"model_name": "GOT-OCR2_0",
"model_family": "ocr",
"model_id": "stepfun-ai/GOT-OCR2_0",
"model_revision": "master",
"model_hub": "modelscope",
"model_ability": [
"ocr"
]
}
]
13 changes: 13 additions & 0 deletions xinference/model/image/ocr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading