Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: GET /supported_devices API を tts_pipeline router へ移動 #1444

Merged
merged 2 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]:
generate_library_router(library_manager, verify_mutability_allowed)
)
app.include_router(generate_user_dict_router(user_dict, verify_mutability_allowed))
app.include_router(
generate_engine_info_router(core_version_list, tts_engines, engine_manifest)
)
app.include_router(generate_engine_info_router(core_version_list, engine_manifest))
app.include_router(
generate_setting_router(
setting_loader, engine_manifest.brand_name, verify_mutability_allowed
Expand Down
42 changes: 2 additions & 40 deletions voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,13 @@
"""エンジンの情報機能を提供する API Router"""

from typing import Self

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from fastapi import APIRouter

from voicevox_engine import __version__
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.engine_manifest import EngineManifest
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager


class SupportedDevicesInfo(BaseModel):
"""
対応しているデバイスの情報
"""

cpu: bool = Field(description="CPUに対応しているか")
cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか")
dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか")

@classmethod
def generate_from(cls, device_support: DeviceSupport) -> Self:
"""`DeviceSupport` インスタンスからこのインスタンスを生成する。"""
return cls(
cpu=device_support.cpu,
cuda=device_support.cuda,
dml=device_support.dml,
)


def generate_engine_info_router(
core_version_list: list[str],
tts_engine_manager: TTSEngineManager,
engine_manifest_data: EngineManifest,
core_version_list: list[str], engine_manifest_data: EngineManifest
) -> APIRouter:
"""エンジン情報 API Router を生成する"""
router = APIRouter(tags=["その他"])
Expand All @@ -49,17 +22,6 @@ async def core_versions() -> list[str]:
"""利用可能なコアのバージョン一覧を取得します。"""
return core_version_list

@router.get("/supported_devices")
def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
version = core_version or LATEST_VERSION
supported_devices = tts_engine_manager.get_engine(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)

@router.get("/engine_manifest")
async def engine_manifest() -> EngineManifest:
"""エンジンマニフェストを取得します。"""
Expand Down
33 changes: 32 additions & 1 deletion voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import zipfile
from tempfile import NamedTemporaryFile, TemporaryFile
from typing import Annotated
from typing import Annotated, Self

import soundfile
from fastapi import APIRouter, HTTPException, Query, Request
Expand All @@ -15,6 +15,7 @@
CancellableEngine,
CancellableEngineInternalError,
)
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.model import AudioQuery
from voicevox_engine.preset.preset_manager import (
Expand Down Expand Up @@ -63,6 +64,25 @@ def __init__(self, err: ParseKanaError):
super().__init__(text=err.text, error_name=err.errname, error_args=err.kwargs)


class SupportedDevicesInfo(BaseModel):
"""
対応しているデバイスの情報
"""

cpu: bool = Field(description="CPUに対応しているか")
cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか")
dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか")

@classmethod
def generate_from(cls, device_support: DeviceSupport) -> Self:
"""`DeviceSupport` インスタンスからこのインスタンスを生成する。"""
return cls(
cpu=device_support.cpu,
cuda=device_support.cuda,
dml=device_support.dml,
)


def generate_tts_pipeline_router(
tts_engines: TTSEngineManager,
preset_manager: PresetManager,
Expand Down Expand Up @@ -543,4 +563,15 @@ def is_initialized_speaker(
engine = tts_engines.get_engine(version)
return engine.is_synthesis_initialized(style_id)

@router.get("/supported_devices", tags=["その他"])
def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
version = core_version or LATEST_VERSION
supported_devices = tts_engines.get_engine(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)

return router