Skip to content

Commit

Permalink
整理: ユーザー辞書機能を API Router でモジュール化 (#1156)
Browse files Browse the repository at this point in the history
* refactor: user_dict 機能をサブサーバに切り出して整理

* refactor: router を生成関数に変更

* fix: lint

* fix: `router` 名を明瞭化

* fix: グローバル変数 FIXME 追加

* fix: lint
  • Loading branch information
tarepan authored Apr 14, 2024
1 parent d369cae commit a54d579
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 190 deletions.
200 changes: 10 additions & 190 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import re
import sys
import traceback
import zipfile
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
Expand All @@ -25,12 +24,17 @@
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError, parse_obj_as
from pydantic import parse_obj_as
from starlette.background import BackgroundTask
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import FileResponse

from voicevox_engine import __version__
from voicevox_engine.app.dependencies import (
check_disabled_mutable_api,
deprecated_mutable_api,
)
from voicevox_engine.app.routers import user_dict
from voicevox_engine.cancellable_engine import CancellableEngine
from voicevox_engine.core.core_adapter import CoreAdapter
from voicevox_engine.core.core_initializer import initialize_cores
Expand Down Expand Up @@ -58,9 +62,7 @@
SpeakerInfo,
StyleIdNotFoundError,
SupportedDevicesInfo,
UserDictWord,
VvlibManifest,
WordTypes,
)
from voicevox_engine.morphing import (
get_morphable_targets,
Expand All @@ -80,15 +82,7 @@
TTSEngine,
make_tts_engines_from_cores,
)
from voicevox_engine.user_dict.part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY
from voicevox_engine.user_dict.user_dict import (
apply_word,
delete_word,
import_user_dict,
read_dict,
rewrite_word,
update_dict,
)
from voicevox_engine.user_dict.user_dict import update_dict
from voicevox_engine.utility.connect_base64_waves import (
ConnectBase64WavesException,
connect_base64_waves,
Expand Down Expand Up @@ -219,13 +213,8 @@ async def block_origin_middleware(
status_code=403, content={"detail": "Origin not allowed"}
)

# 許可されていないAPIを無効化する
async def check_disabled_mutable_api() -> None:
if disable_mutable_api:
raise HTTPException(
status_code=403,
detail="エンジンの静的なデータを変更するAPIは無効化されています",
)
if disable_mutable_api:
deprecated_mutable_api.enable = False

engine_manifest_data = EngineManifestLoader(
engine_root() / "engine_manifest.json", engine_root()
Expand Down Expand Up @@ -1092,176 +1081,7 @@ def is_initialized_speaker(
core = get_core(core_version)
return core.is_initialized_style_id_synthesis(style_id)

@app.get(
"/user_dict",
response_model=dict[str, UserDictWord],
response_description="単語のUUIDとその詳細",
tags=["ユーザー辞書"],
)
def get_user_dict_words() -> dict[str, UserDictWord]:
"""
ユーザー辞書に登録されている単語の一覧を返します。
単語の表層形(surface)は正規化済みの物を返します。
"""
try:
return read_dict()
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="辞書の読み込みに失敗しました。"
)

@app.post(
"/user_dict_word",
response_model=str,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def add_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
pronunciation: Annotated[str, Query(description="言葉の発音(カタカナ)")],
accent_type: Annotated[
int, Query(description="アクセント型(音が下がる場所を指す)")
],
word_type: Annotated[
WordTypes | None,
Query(
description="PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか"
),
] = None,
priority: Annotated[
int | None,
Query(
ge=MIN_PRIORITY,
le=MAX_PRIORITY,
description="単語の優先度(0から10までの整数)。数字が大きいほど優先度が高くなる。1から9までの値を指定することを推奨",
),
] = None,
) -> Response:
"""
ユーザー辞書に言葉を追加します。
"""
try:
word_uuid = apply_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_type=word_type,
priority=priority,
)
return Response(content=word_uuid)
except ValidationError as e:
raise HTTPException(
status_code=422, detail="パラメータに誤りがあります。\n" + str(e)
)
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書への追加に失敗しました。"
)

@app.put(
"/user_dict_word/{word_uuid}",
status_code=204,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def rewrite_user_dict_word(
surface: Annotated[str, Query(description="言葉の表層形")],
pronunciation: Annotated[str, Query(description="言葉の発音(カタカナ)")],
accent_type: Annotated[
int, Query(description="アクセント型(音が下がる場所を指す)")
],
word_uuid: Annotated[str, FAPath(description="更新する言葉のUUID")],
word_type: Annotated[
WordTypes | None,
Query(
description="PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか"
),
] = None,
priority: Annotated[
int | None,
Query(
ge=MIN_PRIORITY,
le=MAX_PRIORITY,
description="単語の優先度(0から10までの整数)。数字が大きいほど優先度が高くなる。1から9までの値を指定することを推奨。",
),
] = None,
) -> Response:
"""
ユーザー辞書に登録されている言葉を更新します。
"""
try:
rewrite_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_uuid=word_uuid,
word_type=word_type,
priority=priority,
)
return Response(status_code=204)
except HTTPException:
raise
except ValidationError as e:
raise HTTPException(
status_code=422, detail="パラメータに誤りがあります。\n" + str(e)
)
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書の更新に失敗しました。"
)

@app.delete(
"/user_dict_word/{word_uuid}",
status_code=204,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def delete_user_dict_word(
word_uuid: Annotated[str, FAPath(description="削除する言葉のUUID")]
) -> Response:
"""
ユーザー辞書に登録されている言葉を削除します。
"""
try:
delete_word(word_uuid=word_uuid)
return Response(status_code=204)
except HTTPException:
raise
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書の更新に失敗しました。"
)

@app.post(
"/import_user_dict",
status_code=204,
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def import_user_dict_words(
import_dict_data: Annotated[
dict[str, UserDictWord],
Body(description="インポートするユーザー辞書のデータ"),
],
override: Annotated[
bool, Query(description="重複したエントリがあった場合、上書きするかどうか")
],
) -> Response:
"""
他のユーザー辞書をインポートします。
"""
try:
import_user_dict(dict_data=import_dict_data, override=override)
return Response(status_code=204)
except Exception:
traceback.print_exc()
raise HTTPException(
status_code=422, detail="ユーザー辞書のインポートに失敗しました。"
)
app.include_router(user_dict.generate_router())

@app.get("/supported_devices", response_model=SupportedDevicesInfo, tags=["その他"])
def supported_devices(
Expand Down
Empty file added voicevox_engine/app/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions voicevox_engine/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from fastapi import HTTPException


# 許可されていないAPIを無効化する
@dataclass
class MutableAPI:
enable: bool = True


# FIXME: グローバル変数が複数ファイルに分散しているため、DI 等で局所化する
deprecated_mutable_api = MutableAPI()


async def check_disabled_mutable_api() -> None:
if not deprecated_mutable_api.enable:
raise HTTPException(
status_code=403,
detail="エンジンの静的なデータを変更するAPIは無効化されています",
)
Empty file.
Loading

0 comments on commit a54d579

Please sign in to comment.