Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: device に関して API と音声ライブラリを分離 #1250

Merged
merged 5 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
"""エンジンの情報機能を提供する API Router"""

import json
from typing import Self

from fastapi import APIRouter, HTTPException, Response
from pydantic import BaseModel, Field

from voicevox_engine import __version__
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.engine_manifest.EngineManifest import EngineManifest
from voicevox_engine.model import SupportedDevicesInfo


class SupportedDevicesInfo(BaseModel):
"""
対応しているデバイスの情報
"""

cpu: bool = Field(title="CPUに対応しているか")
cuda: bool = Field(title="CUDA(Nvidia GPU)に対応しているか")
dml: bool = Field(title="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか")

@classmethod
def generate_from(cls, device_support: DeviceSupport) -> Self:
"""`DeviceSupport` インスタンスからこのインスタンスを生成する。"""
return cls(
cpu=device_support.cpu,
cuda=device_support.cuda,
dml=device_support.dml,
)


def generate_engine_info_router(
Expand All @@ -32,15 +53,12 @@ async def core_versions() -> Response:
@router.get(
"/supported_devices", response_model=SupportedDevicesInfo, tags=["その他"]
)
def supported_devices(core_version: str | None = None) -> Response:
def supported_devices(core_version: str | None = None) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
supported_devices = core_manager.get_core(core_version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return Response(
content=supported_devices,
media_type="application/json",
)
return SupportedDevicesInfo.generate_from(supported_devices)

@router.get("/engine_manifest", response_model=EngineManifest, tags=["その他"])
async def engine_manifest() -> EngineManifest:
Expand Down
25 changes: 21 additions & 4 deletions voicevox_engine/core/core_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import threading
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray
Expand All @@ -7,6 +9,15 @@
from .core_wrapper import CoreWrapper, OldCoreError


@dataclass(frozen=True)
class DeviceSupport:
"""音声ライブラリのデバイス利用可否"""

cpu: bool
cuda: bool # CUDA (Nvidia GPU)
dml: bool # DirectML (Nvidia GPU/Radeon GPU等)


class CoreAdapter:
"""
コアのアダプター。
Expand All @@ -28,13 +39,19 @@ def speakers(self) -> str:
return self.core.metas()

@property
def supported_devices(self) -> str | None:
def supported_devices(self) -> DeviceSupport | None:
"""デバイスサポート情報(None: 情報無し)"""
try:
supported_devices = self.core.supported_devices()
supported_devices = json.loads(self.core.supported_devices())
assert isinstance(supported_devices, dict)
device_support = DeviceSupport(
cpu=supported_devices["cpu"],
cuda=supported_devices["cuda"],
dml=supported_devices["dml"],
)
except OldCoreError:
supported_devices = None
return supported_devices
device_support = None
return device_support

def initialize_style_id_synthesis(
self, style_id: StyleId, skip_reinit: bool
Expand Down
10 changes: 0 additions & 10 deletions voicevox_engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,6 @@ class WordTypes(str, Enum):
SUFFIX = "SUFFIX"


class SupportedDevicesInfo(BaseModel):
"""
対応しているデバイスの情報
"""

cpu: bool = Field(title="CPUに対応しているか")
cuda: bool = Field(title="CUDA(Nvidia GPU)に対応しているか")
dml: bool = Field(title="DirectML(Nvidia GPU/Radeon GPU等)に対応しているか")


class SupportedFeaturesInfo(BaseModel):
"""
エンジンの機能の情報
Expand Down