Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Support CogVideoX video model #2049

Merged
merged 15 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ all =
sentence-transformers>=2.7.0
vllm>=0.2.6 ; sys_platform=='linux'
diffusers>=0.25.0 # fix conflict with matcha-tts
imageio-ffmpeg # For video
controlnet_aux
orjson
auto-gptq ; sys_platform!='darwin'
Expand Down Expand Up @@ -158,6 +159,9 @@ rerank =
image =
diffusers>=0.25.0 # fix conflict with matcha-tts
controlnet_aux
video =
diffusers
imageio-ffmpeg
audio =
funasr
omegaconf~=2.3.0
Expand Down
52 changes: 52 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
CreateCompletion,
ImageList,
PeftModelConfig,
VideoList,
max_tokens_field,
)
from .oauth2.auth_service import AuthService
Expand Down Expand Up @@ -123,6 +124,14 @@ class TextToImageRequest(BaseModel):
user: Optional[str] = None


class TextToVideoRequest(BaseModel):
    """Request body for POST /v1/video/generations."""

    # UID of the launched video model to use.
    model: str
    # Fix: the description was copy-pasted from the embedding request
    # ("The input to embed.") — it documents the video prompt now.
    prompt: Union[str, List[str]] = Field(
        description="The prompt or prompts to guide video generation."
    )
    # Number of videos to generate per prompt.
    n: Optional[int] = 1
    # JSON-encoded string of extra keyword arguments forwarded to the model.
    kwargs: Optional[str] = None
    user: Optional[str] = None


class SpeechRequest(BaseModel):
model: str
input: str
Expand Down Expand Up @@ -512,6 +521,17 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
self._router.add_api_route(
"/v1/video/generations",
self.create_videos,
methods=["POST"],
response_model=VideoList,
dependencies=(
[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/chat/completions",
self.create_chat_completion,
Expand Down Expand Up @@ -1546,6 +1566,38 @@ async def create_flexible_infer(self, request: Request) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_videos(self, request: Request) -> Response:
    """Handle POST /v1/video/generations: generate videos from a text prompt.

    Resolves the model actor via the supervisor, forwards the prompt and any
    extra kwargs, and returns the model's JSON payload. Lookup failures map
    to 400 (bad model uid) or 500; generation failures map to 400 for
    RuntimeError (including request-limit errors) or 500 otherwise.
    """
    body = TextToVideoRequest.parse_obj(await request.json())
    model_uid = body.model

    # Resolve the model reference first so lookup errors are reported
    # separately from generation errors.
    try:
        supervisor_ref = await self._get_supervisor_ref()
        model_ref = await supervisor_ref.get_model(model_uid)
    except ValueError as ve:
        logger.error(str(ve), exc_info=True)
        await self._report_error_event(model_uid, str(ve))
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

    try:
        # kwargs arrives as a JSON string; an absent value means no extras.
        extra_kwargs = json.loads(body.kwargs) if body.kwargs else {}
        video_list = await model_ref.text_to_video(
            prompt=body.prompt,
            n=body.n,
            **extra_kwargs,
        )
        return Response(content=video_list, media_type="application/json")
    except RuntimeError as re:
        logger.error(re, exc_info=True)
        await self._report_error_event(model_uid, str(re))
        # May re-raise as a rate-limit-specific HTTP error.
        self.handle_request_limit_error(re)
        raise HTTPException(status_code=400, detail=str(re))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

async def create_chat_completion(self, request: Request) -> Response:
raw_body = await request.json()
body = CreateChatCompletion.parse_obj(raw_body)
Expand Down
43 changes: 43 additions & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ImageList,
LlamaCppGenerateConfig,
PytorchGenerateConfig,
VideoList,
)


Expand Down Expand Up @@ -370,6 +371,44 @@ def inpainting(
return response_data


class RESTfulVideoModelHandle(RESTfulModelHandle):
    """Client-side handle for a launched video model served over REST."""

    def text_to_video(
        self,
        prompt: str,
        n: int = 1,
        **kwargs,
    ) -> "VideoList":
        """
        Create videos from the input text via ``/v1/video/generations``.

        Parameters
        ----------
        prompt: `str`
            The prompt to guide video generation.
        n: `int`, defaults to 1
            The number of videos to generate for the prompt.
        **kwargs
            Extra keyword arguments forwarded to the server-side model
            (JSON-serialized, so values must be JSON-compatible).

        Returns
        -------
        VideoList
            A list of video objects.

        Raises
        ------
        RuntimeError
            If the server responds with a non-200 status.
        """
        # NOTE: the original docstring claimed "str or List[str]" and
        # referenced `prompt_embeds` (copy-pasted from the image API); the
        # docstring now matches the declared signature.
        url = f"{self._base_url}/v1/video/generations"
        request_body = {
            "model": self._model_uid,
            "prompt": prompt,
            "n": n,
            # Extra kwargs travel as a JSON string field, mirroring the
            # server's TextToVideoRequest schema.
            "kwargs": json.dumps(kwargs),
        }
        response = requests.post(url, json=request_body, headers=self.auth_headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the video, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        return response_data


class RESTfulGenerateModelHandle(RESTfulModelHandle):
def generate(
self,
Expand Down Expand Up @@ -1015,6 +1054,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
return RESTfulAudioModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
elif desc["model_type"] == "video":
return RESTfulVideoModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
elif desc["model_type"] == "flexible":
return RESTfulFlexibleModelHandle(
model_uid, self.base_url, auth_headers=self._headers
Expand Down
1 change: 1 addition & 0 deletions xinference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get_xinference_home() -> str:
XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")
XINFERENCE_IMAGE_DIR = os.path.join(XINFERENCE_HOME, "image")
XINFERENCE_VIDEO_DIR = os.path.join(XINFERENCE_HOME, "video")
XINFERENCE_AUTH_DIR = os.path.join(XINFERENCE_HOME, "auth")
XINFERENCE_CSG_ENDPOINT = str(
os.environ.get(XINFERENCE_ENV_CSG_ENDPOINT, "https://hub-stg.opencsg.com/")
Expand Down
21 changes: 21 additions & 0 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,27 @@ async def infer(
f"Model {self._model.model_spec} is not for flexible infer."
)

@log_async(logger=logger)
@request_limit
async def text_to_video(
    self,
    prompt: str,
    n: int = 1,
    *args,
    **kwargs,
):
    """Dispatch a text-to-video request to the wrapped model.

    Raises AttributeError when the underlying model does not implement
    ``text_to_video`` (i.e. it is not a video model).
    """
    # Guard clause: reject non-video models up front.
    if not hasattr(self._model, "text_to_video"):
        raise AttributeError(
            f"Model {self._model.model_spec} is not for creating video."
        )
    return await self._call_wrapper_json(
        self._model.text_to_video,
        prompt,
        n,
        *args,
        **kwargs,
    )

async def record_metrics(self, name, op, kwargs):
worker_ref = await self._get_worker_ref()
await worker_ref.record_metrics(name, op, kwargs)
37 changes: 37 additions & 0 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from ..model.image import ImageModelFamilyV1
from ..model.llm import LLMFamilyV1
from ..model.rerank import RerankModelSpec
from ..model.video import VideoModelFamilyV1
from .worker import WorkerActor


Expand Down Expand Up @@ -484,6 +485,31 @@ async def _to_audio_model_reg(
res["model_instance_count"] = instance_cnt
return res

async def _to_video_model_reg(
    self, model_family: "VideoModelFamilyV1", is_builtin: bool
) -> Dict[str, Any]:
    """Build the registration info dict for a video model family.

    Includes cache status only for local deployments, plus version and
    instance counts from the supervisor.
    """
    from ..model.video import get_cache_status

    instance_cnt = await self.get_instance_count(model_family.model_name)
    version_cnt = await self.get_model_version_count(model_family.model_name)

    res: Dict[str, Any] = model_family.dict()
    if self.is_local_deployment():
        # TODO: does not work when the supervisor and worker are running on separate nodes.
        res["cache_status"] = get_cache_status(model_family)
    res["is_builtin"] = is_builtin
    res["model_version_count"] = version_cnt
    res["model_instance_count"] = instance_cnt
    return res

async def _to_flexible_model_reg(
self, model_spec: "FlexibleModelSpec", is_builtin: bool
) -> Dict[str, Any]:
Expand Down Expand Up @@ -602,6 +628,17 @@ def sort_helper(item):
{"model_name": model_spec.model_name, "is_builtin": False}
)

ret.sort(key=sort_helper)
return ret
elif model_type == "video":
from ..model.video import BUILTIN_VIDEO_MODELS

for model_name, family in BUILTIN_VIDEO_MODELS.items():
if detailed:
ret.append(await self._to_video_model_reg(family, is_builtin=True))
else:
ret.append({"model_name": model_name, "is_builtin": True})

ret.sort(key=sort_helper)
return ret
elif model_type == "rerank":
Expand Down
2 changes: 2 additions & 0 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,8 @@ async def _get_model_ability(self, model: Any, model_type: str) -> List[str]:
return ["text_to_image"]
elif model_type == "audio":
return ["audio_to_text"]
elif model_type == "video":
return ["text_to_video"]
elif model_type == "flexible":
return ["flexible"]
else:
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # Fo
openai-whisper # For CosyVoice
boto3>=1.28.55,<1.28.65 # For tensorizer
tensorizer~=2.9.0
imageio-ffmpeg # For video

# sglang
outlines>=0.0.44
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ matcha-tts # For CosyVoice
onnxruntime-gpu==1.16.0; sys_platform == 'linux' # For CosyVoice
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
openai-whisper # For CosyVoice
imageio-ffmpeg # For video
12 changes: 12 additions & 0 deletions xinference/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def create_model_instance(
from .image.core import create_image_model_instance
from .llm.core import create_llm_model_instance
from .rerank.core import create_rerank_model_instance
from .video.core import create_video_model_instance

if model_type == "LLM":
return create_llm_model_instance(
Expand Down Expand Up @@ -127,6 +128,17 @@ def create_model_instance(
model_path,
**kwargs,
)
elif model_type == "video":
kwargs.pop("trust_remote_code", None)
return create_video_model_instance(
subpool_addr,
devices,
model_uid,
model_name,
download_hub,
model_path,
**kwargs,
)
elif model_type == "flexible":
kwargs.pop("trust_remote_code", None)
return create_flexible_model_instance(
Expand Down
62 changes: 62 additions & 0 deletions xinference/model/video/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import json
import os
from itertools import chain

from .core import (
BUILTIN_VIDEO_MODELS,
MODEL_NAME_TO_REVISION,
MODELSCOPE_VIDEO_MODELS,
VIDEO_MODEL_DESCRIPTIONS,
VideoModelFamilyV1,
generate_video_description,
get_cache_status,
get_video_model_descriptions,
)

# Locations of the builtin model spec files bundled with this package.
_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
_model_spec_modelscope_json = os.path.join(
    os.path.dirname(__file__), "model_spec_modelscope.json"
)

# Fix: open the spec files with context managers so the handles are closed;
# the previous code leaked the file objects returned by codecs.open().
with codecs.open(_model_spec_json, "r", encoding="utf-8") as _spec_file:
    BUILTIN_VIDEO_MODELS.update(
        (spec["model_name"], VideoModelFamilyV1(**spec))
        for spec in json.load(_spec_file)
    )
for model_name, model_spec in BUILTIN_VIDEO_MODELS.items():
    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)

with codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8") as _spec_file:
    MODELSCOPE_VIDEO_MODELS.update(
        (spec["model_name"], VideoModelFamilyV1(**spec))
        for spec in json.load(_spec_file)
    )
for model_name, model_spec in MODELSCOPE_VIDEO_MODELS.items():
    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)

# Register model descriptions for every known family. ModelScope entries are
# chained first so builtin entries take precedence on duplicate names.
for model_name, model_spec in chain(
    MODELSCOPE_VIDEO_MODELS.items(), BUILTIN_VIDEO_MODELS.items()
):
    VIDEO_MODEL_DESCRIPTIONS.update(generate_video_description(model_spec))

# Keep the module namespace clean.
del _model_spec_json
del _model_spec_modelscope_json
del _spec_file
Loading
Loading