Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: UserDictionary クラスを追加 #1222

Merged: 1 commit merged on May 13, 2024.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build_util/make_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from voicevox_engine.preset.PresetManager import PresetManager
from voicevox_engine.setting.SettingLoader import USER_SETTING_PATH, SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import CoreAdapter
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.path_utility import engine_root


Expand Down Expand Up @@ -44,6 +45,7 @@ def generate_api_docs_html(schema: str) -> str:
preset_manager=PresetManager( # FIXME: impl MockPresetManager
preset_path=engine_root() / "presets.yaml",
),
user_dict=UserDictionary(),
)
api_schema = json.dumps(app.openapi())

Expand Down
4 changes: 4 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from voicevox_engine.setting.Setting import CorsPolicyMode
from voicevox_engine.setting.SettingLoader import USER_SETTING_PATH, SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import make_tts_engines_from_cores
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.core_version_utility import get_latest_version
from voicevox_engine.utility.path_utility import engine_root
from voicevox_engine.utility.run_utility import decide_boolean_from_env
Expand Down Expand Up @@ -294,6 +295,8 @@ def main() -> None:
# ファイルの存在に関わらず指定されたパスをプリセットファイルとして使用する
preset_manager = PresetManager(preset_path)

use_dict = UserDictionary()

if arg_disable_mutable_api:
disable_mutable_api = True
else:
Expand All @@ -306,6 +309,7 @@ def main() -> None:
latest_core_version,
setting_loader,
preset_manager,
use_dict,
cancellable_engine,
root_dir,
cors_policy_mode,
Expand Down
3 changes: 3 additions & 0 deletions test/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from voicevox_engine.preset.PresetManager import PresetManager
from voicevox_engine.setting.SettingLoader import SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import make_tts_engines_from_cores
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.core_version_utility import get_latest_version


Expand All @@ -26,13 +27,15 @@ def app_params(tmp_path: Path) -> dict[str, Any]:
preset_path = tmp_path / "presets.yaml"
shutil.copyfile(original_preset_path, preset_path)
preset_manager = PresetManager(preset_path)
user_dict = UserDictionary()

return {
"tts_engines": tts_engines,
"cores": cores,
"latest_core_version": latest_core_version,
"setting_loader": setting_loader,
"preset_manager": preset_manager,
"user_dict": user_dict,
}


Expand Down
138 changes: 70 additions & 68 deletions test/user_dict/test_user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,8 @@
)
from voicevox_engine.user_dict.user_dict import (
UserDictInputError,
UserDictionary,
_create_word,
# [review thread on this line: marked as resolved by tarepan]
apply_word,
delete_word,
import_user_dict,
read_dict,
rewrite_word,
update_dict,
)

# jsonとして保存される正しい形式の辞書データ
Expand Down Expand Up @@ -85,15 +80,22 @@ def tearDown(self) -> None:
self.tmp_dir.cleanup()

def test_read_not_exist_json(self) -> None:
user_dict = UserDictionary(user_dict_path=self.tmp_dir_path / "not_exist.json")
self.assertEqual(
read_dict(user_dict_path=(self.tmp_dir_path / "not_exist.json")),
user_dict.read_dict(),
{},
)

def test_create_word(self) -> None:
# 将来的に品詞などが追加された時にテストを増やす
self.assertEqual(
_create_word(surface="test", pronunciation="テスト", accent_type=1),
_create_word(
surface="test",
pronunciation="テスト",
accent_type=1,
word_type=None,
priority=None,
),
UserDictWord(
surface="test",
priority=5,
Expand All @@ -113,15 +115,13 @@ def test_create_word(self) -> None:
)

def test_apply_word_without_json(self) -> None:
user_dict_path = self.tmp_dir_path / "test_apply_word_without_json.json"
apply_word(
surface="test",
pronunciation="テスト",
accent_type=1,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_apply_word_without_json.dic"),

user_dict = UserDictionary(
user_dict_path=self.tmp_dir_path / "test_apply_word_without_json.json",
compiled_dict_path=self.tmp_dir_path / "test_apply_word_without_json.dic",
)
res = read_dict(user_dict_path=user_dict_path)
user_dict.apply_word(surface="test", pronunciation="テスト", accent_type=1)
res = user_dict.read_dict()
self.assertEqual(len(res), 1)
new_word = get_new_word(res)
self.assertEqual(
Expand All @@ -138,14 +138,16 @@ def test_apply_word_with_json(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
apply_word(
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=self.tmp_dir_path / "test_apply_word_with_json.dic",
)
user_dict.apply_word(
surface="test2",
pronunciation="テストツー",
accent_type=3,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_apply_word_with_json.dic"),
)
res = read_dict(user_dict_path=user_dict_path)
res = user_dict.read_dict()
self.assertEqual(len(res), 2)
new_word = get_new_word(res)
self.assertEqual(
Expand All @@ -162,33 +164,35 @@ def test_rewrite_word_invalid_id(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_invalid_id.dic"),
)
self.assertRaises(
UserDictInputError,
rewrite_word,
user_dict.rewrite_word,
word_uuid="c2be4dc5-d07d-4767-8be1-04a1bb3f05a9",
surface="test2",
pronunciation="テストツー",
accent_type=2,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_invalid_id.dic"),
)

def test_rewrite_word_valid_id(self) -> None:
user_dict_path = self.tmp_dir_path / "test_rewrite_word_valid_id.json"
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
rewrite_word(
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=self.tmp_dir_path / "test_rewrite_word_valid_id.dic",
)
user_dict.rewrite_word(
word_uuid="aab7dda2-0d97-43c8-8cb7-3f440dab9b4e",
surface="test2",
pronunciation="テストツー",
accent_type=2,
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_valid_id.dic"),
)
new_word = read_dict(user_dict_path=user_dict_path)[
"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
]
new_word = user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"]
self.assertEqual(
(new_word.surface, new_word.pronunciation, new_word.accent_type),
("test2", "テストツー", 2),
Expand All @@ -199,25 +203,27 @@ def test_delete_word_invalid_id(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=self.tmp_dir_path / "test_delete_word_invalid_id.dic",
)
self.assertRaises(
UserDictInputError,
delete_word,
user_dict.delete_word,
word_uuid="c2be4dc5-d07d-4767-8be1-04a1bb3f05a9",
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_delete_word_invalid_id.dic"),
)

def test_delete_word_valid_id(self) -> None:
user_dict_path = self.tmp_dir_path / "test_delete_word_valid_id.json"
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
delete_word(
word_uuid="aab7dda2-0d97-43c8-8cb7-3f440dab9b4e",
user_dict = UserDictionary(
user_dict_path=user_dict_path,
compiled_dict_path=(self.tmp_dir_path / "test_delete_word_valid_id.dic"),
compiled_dict_path=self.tmp_dir_path / "test_delete_word_valid_id.dic",
)
self.assertEqual(len(read_dict(user_dict_path=user_dict_path)), 0)
user_dict.delete_word(word_uuid="aab7dda2-0d97-43c8-8cb7-3f440dab9b4e")
self.assertEqual(len(user_dict.read_dict()), 0)

def test_priority(self) -> None:
for pos in part_of_speech_data:
Expand All @@ -239,18 +245,18 @@ def test_import_dict(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
import_user_dict(
{"b1affe2a-d5f0-4050-926c-f28e0c1d9a98": import_word},
override=False,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.import_user_dict(
{"b1affe2a-d5f0-4050-926c-f28e0c1d9a98": import_word}, override=False
)
self.assertEqual(
read_dict(user_dict_path)["b1affe2a-d5f0-4050-926c-f28e0c1d9a98"],
user_dict.read_dict()["b1affe2a-d5f0-4050-926c-f28e0c1d9a98"],
import_word,
)
self.assertEqual(
read_dict(user_dict_path)["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
UserDictWord(**valid_dict_dict_api["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"]),
)

Expand All @@ -260,14 +266,14 @@ def test_import_dict_no_override(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word},
override=False,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word}, override=False
)
self.assertEqual(
read_dict(user_dict_path)["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
UserDictWord(**valid_dict_dict_api["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"]),
)

Expand All @@ -277,14 +283,14 @@ def test_import_dict_override(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word},
override=True,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.import_user_dict(
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": import_word}, override=True
)
self.assertEqual(
read_dict(user_dict_path)["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"],
import_word,
)

Expand All @@ -296,15 +302,16 @@ def test_import_invalid_word(self) -> None:
user_dict_path.write_text(
json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
)
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
self.assertRaises(
AssertionError,
import_user_dict,
user_dict.import_user_dict,
{
"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": invalid_accent_associative_rule_word
},
override=True,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
)
invalid_pos_word = deepcopy(import_word)
invalid_pos_word.context_id = 2
Expand All @@ -314,39 +321,34 @@ def test_import_invalid_word(self) -> None:
invalid_pos_word.part_of_speech_detail_3 = "*"
self.assertRaises(
ValueError,
import_user_dict,
user_dict.import_user_dict,
{"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": invalid_pos_word},
override=True,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
)

def test_update_dict(self) -> None:
user_dict_path = self.tmp_dir_path / "test_update_dict.json"
compiled_dict_path = self.tmp_dir_path / "test_update_dict.dic"
update_dict(
user_dict = UserDictionary(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.update_dict()
test_text = "テスト用の文字列"
success_pronunciation = "デフォルトノジショデハゼッタイニセイセイサレナイヨミ"

# 既に辞書に登録されていないか確認する
self.assertNotEqual(g2p(text=test_text, kana=True), success_pronunciation)

apply_word(
user_dict.apply_word(
surface=test_text,
pronunciation=success_pronunciation,
accent_type=1,
priority=10,
user_dict_path=user_dict_path,
compiled_dict_path=compiled_dict_path,
)
self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)

# 疑似的にエンジンを再起動する
unset_user_dict()
update_dict(
user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
)
user_dict.update_dict()

self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)
7 changes: 4 additions & 3 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from voicevox_engine.setting.Setting import CorsPolicyMode
from voicevox_engine.setting.SettingLoader import SettingHandler
from voicevox_engine.tts_pipeline.tts_engine import TTSEngine
from voicevox_engine.user_dict.user_dict import update_dict
from voicevox_engine.user_dict.user_dict import UserDictionary
from voicevox_engine.utility.path_utility import engine_root, get_save_dir


Expand All @@ -36,6 +36,7 @@ def generate_app(
latest_core_version: str,
setting_loader: SettingHandler,
preset_manager: PresetManager,
user_dict: UserDictionary,
cancellable_engine: CancellableEngine | None = None,
root_dir: Path | None = None,
cors_policy_mode: CorsPolicyMode = CorsPolicyMode.localapps,
Expand All @@ -48,7 +49,7 @@ def generate_app(

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
update_dict()
user_dict.update_dict()
yield

app = FastAPI(
Expand Down Expand Up @@ -102,7 +103,7 @@ def get_core(core_version: str | None) -> CoreAdapter:
app.include_router(
generate_library_router(engine_manifest_data, library_manager)
)
app.include_router(generate_user_dict_router())
app.include_router(generate_user_dict_router(user_dict))
app.include_router(
generate_engine_info_router(get_core, cores, engine_manifest_data)
)
Expand Down
Loading