Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: ユーザー辞書機能を API Router でモジュール化 #1156

Merged
merged 6 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 7 additions & 190 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import re
import sys
import traceback
import zipfile
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
Expand All @@ -25,12 +24,14 @@
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError, parse_obj_as
from pydantic import parse_obj_as
from starlette.background import BackgroundTask
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import FileResponse

from voicevox_engine import __version__
from voicevox_engine.app.dependencies import check_disabled_mutable_api, mutable_api
from voicevox_engine.app.routers import user_dict
from voicevox_engine.cancellable_engine import CancellableEngine
from voicevox_engine.core.core_adapter import CoreAdapter
from voicevox_engine.core.core_initializer import initialize_cores
Expand Down Expand Up @@ -58,9 +59,7 @@
SpeakerInfo,
StyleIdNotFoundError,
SupportedDevicesInfo,
UserDictWord,
VvlibManifest,
WordTypes,
)
from voicevox_engine.morphing import (
get_morphable_targets,
Expand All @@ -80,15 +79,7 @@
TTSEngine,
make_tts_engines_from_cores,
)
from voicevox_engine.user_dict.part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY
from voicevox_engine.user_dict.user_dict import (
apply_word,
delete_word,
import_user_dict,
read_dict,
rewrite_word,
update_dict,
)
from voicevox_engine.user_dict.user_dict import update_dict
from voicevox_engine.utility.connect_base64_waves import (
ConnectBase64WavesException,
connect_base64_waves,
Expand Down Expand Up @@ -219,13 +210,8 @@ async def block_origin_middleware(
status_code=403, content={"detail": "Origin not allowed"}
)

# 許可されていないAPIを無効化する
async def check_disabled_mutable_api() -> None:
if disable_mutable_api:
raise HTTPException(
status_code=403,
detail="エンジンの静的なデータを変更するAPIは無効化されています",
)
if disable_mutable_api:
mutable_api.enable = False

engine_manifest_data = EngineManifestLoader(
engine_root() / "engine_manifest.json", engine_root()
Expand Down Expand Up @@ -1092,176 +1078,7 @@ def is_initialized_speaker(
core = get_core(core_version)
return core.is_initialized_style_id_synthesis(style_id)

@app.get(
"/user_dict",
response_model=dict[str, UserDictWord],
response_description="単語のUUIDとその詳細",
tags=["ユーザー辞書"],
)
def get_user_dict_words() -> dict[str, UserDictWord]:
"""
ユーザー辞書に登録されている単語の一覧を返します。
単語の表層形(surface)は正規化済みの物を返します。
"""
try:
return read_dict()
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="辞書の読み込みに失敗しました。"
)

@app.post(
"/user_dict_word",
response_model=str,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def add_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
pronunciation: Annotated[str, Query(description="言葉の発音(カタカナ)")],
accent_type: Annotated[
int, Query(description="アクセント型(音が下がる場所を指す)")
],
word_type: Annotated[
WordTypes | None,
Query(
description="PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか"
),
] = None,
priority: Annotated[
int | None,
Query(
ge=MIN_PRIORITY,
le=MAX_PRIORITY,
description="単語の優先度(0から10までの整数)。数字が大きいほど優先度が高くなる。1から9までの値を指定することを推奨",
),
] = None,
) -> Response:
"""
ユーザー辞書に言葉を追加します。
"""
try:
word_uuid = apply_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_type=word_type,
priority=priority,
)
return Response(content=word_uuid)
except ValidationError as e:
raise HTTPException(
status_code=422, detail="パラメータに誤りがあります。\n" + str(e)
)
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書への追加に失敗しました。"
)

@app.put(
"/user_dict_word/{word_uuid}",
status_code=204,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def rewrite_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
pronunciation: Annotated[str, Query(description="言葉の発音(カタカナ)")],
accent_type: Annotated[
int, Query(description="アクセント型(音が下がる場所を指す)")
],
word_uuid: Annotated[str, FAPath(description="更新する言葉のUUID")],
word_type: Annotated[
WordTypes | None,
Query(
description="PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか"
),
] = None,
priority: Annotated[
int | None,
Query(
ge=MIN_PRIORITY,
le=MAX_PRIORITY,
description="単語の優先度(0から10までの整数)。数字が大きいほど優先度が高くなる。1から9までの値を指定することを推奨。",
),
] = None,
) -> Response:
"""
ユーザー辞書に登録されている言葉を更新します。
"""
try:
rewrite_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_uuid=word_uuid,
word_type=word_type,
priority=priority,
)
return Response(status_code=204)
except HTTPException:
raise
except ValidationError as e:
raise HTTPException(
status_code=422, detail="パラメータに誤りがあります。\n" + str(e)
)
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書の更新に失敗しました。"
)

@app.delete(
"/user_dict_word/{word_uuid}",
status_code=204,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def delete_user_dict_word(
word_uuid: Annotated[str, FAPath(description="削除する言葉のUUID")]
) -> Response:
"""
ユーザー辞書に登録されている言葉を削除します。
"""
try:
delete_word(word_uuid=word_uuid)
return Response(status_code=204)
except HTTPException:
raise
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書の更新に失敗しました。"
)

@app.post(
"/import_user_dict",
status_code=204,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def import_user_dict_words(
import_dict_data: Annotated[
dict[str, UserDictWord],
Body(description="インポートするユーザー辞書のデータ"),
],
override: Annotated[
bool, Query(description="重複したエントリがあった場合、上書きするかどうか")
],
) -> Response:
"""
他のユーザー辞書をインポートします。
"""
try:
import_user_dict(dict_data=import_dict_data, override=override)
return Response(status_code=204)
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書のインポートに失敗しました。"
)
app.include_router(user_dict.generate_router())

@app.get("/supported_devices", response_model=SupportedDevicesInfo, tags=["その他"])
def supported_devices(
Expand Down
Empty file.
20 changes: 20 additions & 0 deletions voicevox_engine/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass

from fastapi import HTTPException


# 許可されていないAPIを無効化する
@dataclass
class MutableAPI:
enable: bool = True


mutable_api = MutableAPI()
tarepan marked this conversation as resolved.
Show resolved Hide resolved


async def check_disabled_mutable_api() -> None:
if not mutable_api.enable:
raise HTTPException(
status_code=403,
detail="エンジンの静的なデータを変更するAPIは無効化されています",
)
Empty file.
Loading
Loading