From 761adf1441199282ffaf3d145ecc48ada30e60ee Mon Sep 17 00:00:00 2001 From: tarepan Date: Thu, 27 Jun 2024 06:30:08 +0000 Subject: [PATCH] =?UTF-8?q?refactor:=20`supported=5Fdevices`=20API=20?= =?UTF-8?q?=E3=82=92=20`tts=5Fpipeline`=20router=20=E3=81=B8=E7=A7=BB?= =?UTF-8?q?=E5=8B=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- voicevox_engine/app/application.py | 4 +- voicevox_engine/app/routers/engine_info.py | 42 +-------------------- voicevox_engine/app/routers/tts_pipeline.py | 33 +++++++++++++++- 3 files changed, 35 insertions(+), 44 deletions(-) diff --git a/voicevox_engine/app/application.py b/voicevox_engine/app/application.py index 90611572f..6a082ac33 100644 --- a/voicevox_engine/app/application.py +++ b/voicevox_engine/app/application.py @@ -94,9 +94,7 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]: generate_library_router(library_manager, verify_mutability_allowed) ) app.include_router(generate_user_dict_router(user_dict, verify_mutability_allowed)) - app.include_router( - generate_engine_info_router(core_manager, tts_engines, engine_manifest) - ) + app.include_router(generate_engine_info_router(core_manager, engine_manifest)) app.include_router( generate_setting_router( setting_loader, engine_manifest.brand_name, verify_mutability_allowed diff --git a/voicevox_engine/app/routers/engine_info.py b/voicevox_engine/app/routers/engine_info.py index cfd4a52c6..a7efa5cc4 100644 --- a/voicevox_engine/app/routers/engine_info.py +++ b/voicevox_engine/app/routers/engine_info.py @@ -1,41 +1,14 @@ """エンジンの情報機能を提供する API Router""" -from typing import Self - -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field -from pydantic.json_schema import SkipJsonSchema +from fastapi import APIRouter from voicevox_engine import __version__ -from voicevox_engine.core.core_adapter import DeviceSupport from voicevox_engine.core.core_initializer import CoreManager from voicevox_engine.engine_manifest import EngineManifest -from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager - - -class SupportedDevicesInfo(BaseModel): - """ - 対応しているデバイスの情報 - """ - - cpu: bool = Field(description="CPUに対応しているか") - cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか") - dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか") - - @classmethod - def generate_from(cls, device_support: DeviceSupport) -> Self: - """`DeviceSupport` インスタンスからこのインスタンスを生成する。""" - return cls( - cpu=device_support.cpu, - cuda=device_support.cuda, - dml=device_support.dml, - ) def generate_engine_info_router( - core_manager: CoreManager, - tts_engine_manager: TTSEngineManager, - engine_manifest_data: EngineManifest, + core_manager: CoreManager, engine_manifest_data: EngineManifest ) -> APIRouter: """エンジン情報 API Router を生成する""" router = APIRouter(tags=["その他"]) @@ -50,17 +23,6 @@ async def core_versions() -> list[str]: """利用可能なコアのバージョン一覧を取得します。""" return core_manager.versions() - @router.get("/supported_devices") - def supported_devices( - core_version: str | SkipJsonSchema[None] = None, - ) -> SupportedDevicesInfo: - """対応デバイスの一覧を取得します。""" - version = core_version or core_manager.latest_version() - supported_devices = tts_engine_manager.get_engine(version).supported_devices - if supported_devices is None: - raise HTTPException(status_code=422, detail="非対応の機能です。") - return SupportedDevicesInfo.generate_from(supported_devices) - @router.get("/engine_manifest") async def engine_manifest() -> EngineManifest: """エンジンマニフェストを取得します。""" diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index 7b113eb39..dffe14e45 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -2,7 +2,7 @@ import zipfile from tempfile import NamedTemporaryFile, TemporaryFile -from typing import Annotated +from typing import Annotated, Self import soundfile from fastapi import APIRouter, HTTPException, Query, Request @@ -15,6 +15,7 @@ CancellableEngine, CancellableEngineInternalError, ) +from voicevox_engine.core.core_adapter import DeviceSupport from voicevox_engine.core.core_initializer import CoreManager from voicevox_engine.metas.Metas import StyleId from voicevox_engine.model import AudioQuery @@ -63,6 +64,25 @@ def __init__(self, err: ParseKanaError): super().__init__(text=err.text, error_name=err.errname, error_args=err.kwargs) +class SupportedDevicesInfo(BaseModel): + """ + 対応しているデバイスの情報 + """ + + cpu: bool = Field(description="CPUに対応しているか") + cuda: bool = Field(description="CUDA(Nvidia GPU)に対応しているか") + dml: bool = Field(description="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか") + + @classmethod + def generate_from(cls, device_support: DeviceSupport) -> Self: + """`DeviceSupport` インスタンスからこのインスタンスを生成する。""" + return cls( + cpu=device_support.cpu, + cuda=device_support.cuda, + dml=device_support.dml, + ) + + def generate_tts_pipeline_router( tts_engines: TTSEngineManager, core_manager: CoreManager, @@ -544,4 +564,15 @@ def is_initialized_speaker( engine = tts_engines.get_engine(version) return engine.is_synthesis_initialized(style_id) + @router.get("/supported_devices", tags=["その他"]) + def supported_devices( + core_version: str | SkipJsonSchema[None] = None, + ) -> SupportedDevicesInfo: + """対応デバイスの一覧を取得します。""" + version = core_version or core_manager.latest_version() + supported_devices = tts_engines.get_engine(version).supported_devices + if supported_devices is None: + raise HTTPException(status_code=422, detail="非対応の機能です。") + return SupportedDevicesInfo.generate_from(supported_devices) + return router