Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: モーフィング機能を API Router でモジュール化 #1194

Merged
merged 3 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 4 additions & 130 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
import sys
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from functools import lru_cache
from io import BytesIO, TextIOWrapper
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Annotated, Any, Optional

import soundfile
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi import Path as FAPath
Expand All @@ -22,16 +19,15 @@
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from fastapi.templating import Jinja2Templates
from starlette.background import BackgroundTask
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import FileResponse

from voicevox_engine import __version__
from voicevox_engine.app.dependencies import (
check_disabled_mutable_api,
deprecated_mutable_api,
)
from voicevox_engine.app.routers import (
morphing,
preset,
setting,
speaker,
Expand All @@ -45,25 +41,14 @@
from voicevox_engine.engine_manifest.EngineManifestLoader import EngineManifestLoader
from voicevox_engine.library_manager import LibraryManager
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.metas.MetasStore import MetasStore, construct_lookup
from voicevox_engine.metas.MetasStore import MetasStore
from voicevox_engine.model import (
AudioQuery,
BaseLibraryInfo,
DownloadableLibraryInfo,
InstalledLibraryInfo,
MorphableTargetInfo,
StyleIdNotFoundError,
SupportedDevicesInfo,
VvlibManifest,
)
from voicevox_engine.morphing import (
get_morphable_targets,
is_synthesis_morphing_permitted,
synthesis_morphing,
)
from voicevox_engine.morphing import (
synthesis_morphing_parameter as _synthesis_morphing_parameter,
)
from voicevox_engine.preset.PresetManager import PresetManager
from voicevox_engine.setting.Setting import CorsPolicyMode
from voicevox_engine.setting.SettingLoader import USER_SETTING_PATH, SettingHandler
Expand All @@ -73,7 +58,7 @@
)
from voicevox_engine.user_dict.user_dict import update_dict
from voicevox_engine.utility.core_version_utility import get_latest_core_version
from voicevox_engine.utility.path_utility import delete_file, engine_root, get_save_dir
from voicevox_engine.utility.path_utility import engine_root, get_save_dir
from voicevox_engine.utility.run_utility import decide_boolean_from_env


Expand Down Expand Up @@ -216,11 +201,6 @@ async def block_origin_middleware(
variable_end_string="<JINJA_POST>",
)

# キャッシュを有効化
# モジュール側でlru_cacheを指定するとキャッシュを制御しにくいため、HTTPサーバ側で指定する
# TODO: キャッシュを管理するモジュール側API・HTTP側APIを用意する
synthesis_morphing_parameter = lru_cache(maxsize=4)(_synthesis_morphing_parameter)

# @app.on_event("startup")
# async def start_catch_disconnection():
# if cancellable_engine is not None:
Expand Down Expand Up @@ -248,113 +228,7 @@ def get_core(core_version: Optional[str]) -> CoreAdapter:
)
)

@app.post(
"/morphable_targets",
response_model=list[dict[str, MorphableTargetInfo]],
tags=["音声合成"],
summary="指定したスタイルに対してエンジン内の話者がモーフィングが可能か判定する",
)
def morphable_targets(
base_style_ids: list[StyleId], core_version: str | None = None
) -> list[dict[str, MorphableTargetInfo]]:
"""
指定されたベーススタイルに対してエンジン内の各話者がモーフィング機能を利用可能か返します。
モーフィングの許可/禁止は`/speakers`の`speaker.supported_features.synthesis_morphing`に記載されています。
プロパティが存在しない場合は、モーフィングが許可されているとみなします。
返り値のスタイルIDはstring型なので注意。
"""
core = get_core(core_version)

try:
speakers = metas_store.load_combined_metas(core=core)
morphable_targets = get_morphable_targets(
speakers=speakers, base_style_ids=base_style_ids
)
# jsonはint型のキーを持てないので、string型に変換する
return [
{str(k): v for k, v in morphable_target.items()}
for morphable_target in morphable_targets
]
except StyleIdNotFoundError as e:
raise HTTPException(
status_code=404,
detail=f"該当するスタイル(style_id={e.style_id})が見つかりません",
)

@app.post(
"/synthesis_morphing",
response_class=FileResponse,
responses={
200: {
"content": {
"audio/wav": {"schema": {"type": "string", "format": "binary"}}
},
}
},
tags=["音声合成"],
summary="2種類のスタイルでモーフィングした音声を合成する",
)
def _synthesis_morphing(
query: AudioQuery,
base_style_id: Annotated[StyleId, Query(alias="base_speaker")],
target_style_id: Annotated[StyleId, Query(alias="target_speaker")],
morph_rate: Annotated[float, Query(ge=0.0, le=1.0)],
core_version: str | None = None,
) -> FileResponse:
"""
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
engine = get_engine(core_version)
core = get_core(core_version)

try:
speakers = metas_store.load_combined_metas(core=core)
speaker_lookup = construct_lookup(speakers=speakers)
is_permitted = is_synthesis_morphing_permitted(
speaker_lookup, base_style_id, target_style_id
)
if not is_permitted:
raise HTTPException(
status_code=400,
detail="指定されたスタイルペアでのモーフィングはできません",
)
except StyleIdNotFoundError as e:
raise HTTPException(
status_code=404,
detail=f"該当するスタイル(style_id={e.style_id})が見つかりません",
)

# 生成したパラメータはキャッシュされる
morph_param = synthesis_morphing_parameter(
engine=engine,
core=core,
query=query,
base_style_id=base_style_id,
target_style_id=target_style_id,
)

morph_wave = synthesis_morphing(
morph_param=morph_param,
morph_rate=morph_rate,
output_fs=query.outputSamplingRate,
output_stereo=query.outputStereo,
)

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
file=f,
data=morph_wave,
samplerate=query.outputSamplingRate,
format="WAV",
)

return FileResponse(
f.name,
media_type="audio/wav",
background=BackgroundTask(delete_file, f.name),
)

app.include_router(morphing.generate_router(get_engine, get_core, metas_store))
app.include_router(preset.generate_router(preset_manager))

@app.get("/version", tags=["その他"])
Expand Down
148 changes: 148 additions & 0 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""モーフィング機能を提供する API Router"""

from functools import lru_cache
from tempfile import NamedTemporaryFile
from typing import Annotated, Callable

import soundfile
from fastapi import APIRouter, HTTPException, Query
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from voicevox_engine.core.core_adapter import CoreAdapter
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.metas.MetasStore import MetasStore, construct_lookup
from voicevox_engine.model import AudioQuery, MorphableTargetInfo, StyleIdNotFoundError
from voicevox_engine.morphing import (
get_morphable_targets,
is_synthesis_morphing_permitted,
synthesis_morphing,
)
from voicevox_engine.morphing import (
synthesis_morphing_parameter as _synthesis_morphing_parameter,
)
from voicevox_engine.tts_pipeline.tts_engine import TTSEngine
from voicevox_engine.utility.path_utility import delete_file

# キャッシュを有効化
# モジュール側でlru_cacheを指定するとキャッシュを制御しにくいため、HTTPサーバ側で指定する
# TODO: キャッシュを管理するモジュール側API・HTTP側APIを用意する
synthesis_morphing_parameter = lru_cache(maxsize=4)(_synthesis_morphing_parameter)


def generate_router(
get_engine: Callable[[str | None], TTSEngine],
get_core: Callable[[str | None], CoreAdapter],
metas_store: MetasStore,
) -> APIRouter:
"""モーフィング API Router を生成する"""
router = APIRouter()

@router.post(
"/morphable_targets",
response_model=list[dict[str, MorphableTargetInfo]],
tags=["音声合成"],
summary="指定したスタイルに対してエンジン内の話者がモーフィングが可能か判定する",
)
def morphable_targets(
base_style_ids: list[StyleId], core_version: str | None = None
) -> list[dict[str, MorphableTargetInfo]]:
"""
指定されたベーススタイルに対してエンジン内の各話者がモーフィング機能を利用可能か返します。
モーフィングの許可/禁止は`/speakers`の`speaker.supported_features.synthesis_morphing`に記載されています。
プロパティが存在しない場合は、モーフィングが許可されているとみなします。
返り値のスタイルIDはstring型なので注意。
"""
core = get_core(core_version)

try:
speakers = metas_store.load_combined_metas(core=core)
morphable_targets = get_morphable_targets(
speakers=speakers, base_style_ids=base_style_ids
)
# jsonはint型のキーを持てないので、string型に変換する
return [
{str(k): v for k, v in morphable_target.items()}
for morphable_target in morphable_targets
]
except StyleIdNotFoundError as e:
raise HTTPException(
status_code=404,
detail=f"該当するスタイル(style_id={e.style_id})が見つかりません",
)

@router.post(
"/synthesis_morphing",
response_class=FileResponse,
responses={
200: {
"content": {
"audio/wav": {"schema": {"type": "string", "format": "binary"}}
},
}
},
tags=["音声合成"],
summary="2種類のスタイルでモーフィングした音声を合成する",
)
def _synthesis_morphing(
query: AudioQuery,
base_style_id: Annotated[StyleId, Query(alias="base_speaker")],
target_style_id: Annotated[StyleId, Query(alias="target_speaker")],
morph_rate: Annotated[float, Query(ge=0.0, le=1.0)],
core_version: str | None = None,
) -> FileResponse:
"""
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
engine = get_engine(core_version)
core = get_core(core_version)

try:
speakers = metas_store.load_combined_metas(core=core)
speaker_lookup = construct_lookup(speakers=speakers)
is_permitted = is_synthesis_morphing_permitted(
speaker_lookup, base_style_id, target_style_id
)
if not is_permitted:
raise HTTPException(
status_code=400,
detail="指定されたスタイルペアでのモーフィングはできません",
)
except StyleIdNotFoundError as e:
raise HTTPException(
status_code=404,
detail=f"該当するスタイル(style_id={e.style_id})が見つかりません",
)

# 生成したパラメータはキャッシュされる
morph_param = synthesis_morphing_parameter(
engine=engine,
core=core,
query=query,
base_style_id=base_style_id,
target_style_id=target_style_id,
)

morph_wave = synthesis_morphing(
morph_param=morph_param,
morph_rate=morph_rate,
output_fs=query.outputSamplingRate,
output_stereo=query.outputStereo,
)

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
file=f,
data=morph_wave,
samplerate=query.outputSamplingRate,
format="WAV",
)

return FileResponse(
f.name,
media_type="audio/wav",
background=BackgroundTask(delete_file, f.name),
)

return router
Loading