From 8954c03501b10655906ea0f00542cb1e9e469365 Mon Sep 17 00:00:00 2001 From: sabonerune <102559104+sabonerune@users.noreply.github.com> Date: Thu, 6 Oct 2022 00:49:23 +0900 Subject: [PATCH] =?UTF-8?q?TST:=20user=5Fdict=5Fstartup=5Fprocessing?= =?UTF-8?q?=E3=81=AE=E3=83=86=E3=82=B9=E3=83=88=E3=82=92=E8=BF=BD=E5=8A=A0?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit テストのためuser_dict内の関数に引数を追加 --- test/test_user_dict.py | 33 ++++++++++++++++++++++++++++++++- voicevox_engine/user_dict.py | 18 ++++++++++++------ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/test/test_user_dict.py b/test/test_user_dict.py index 250d65fa5..532797731 100644 --- a/test/test_user_dict.py +++ b/test/test_user_dict.py @@ -6,7 +6,7 @@ from unittest import TestCase from fastapi import HTTPException -from pyopenjtalk import unset_user_dict +from pyopenjtalk import g2p, unset_user_dict from voicevox_engine.model import UserDictWord, WordTypes from voicevox_engine.part_of_speech_data import MAX_PRIORITY, part_of_speech_data @@ -17,6 +17,7 @@ import_user_dict, read_dict, rewrite_word, + user_dict_startup_processing, ) # jsonとして保存される正しい形式の辞書データ @@ -315,3 +316,33 @@ def test_import_invalid_word(self): user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path, ) + + def test_startup_processing(self): + user_dict_path = self.tmp_dir_path / "test_startup_processing_dict.json" + compiled_dict_path = self.tmp_dir_path / "test_startup_processing_dict.dic" + user_dict_startup_processing( + user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path + ) + test_text = "テスト用の文字列" + success_pronunciation = "デフォルトノジショデハゼッタイニセイセイサレナイヨミ" + + # 既に辞書に登録されていないか確認する + self.assertNotEqual(g2p(text=test_text, kana=True), success_pronunciation) + + apply_word( + surface=test_text, + pronunciation=success_pronunciation, + accent_type=1, + priority=10, + user_dict_path=user_dict_path, + compiled_dict_path=compiled_dict_path, + ) + self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation) + + # 疑似的にエンジンを再起動する + unset_user_dict() + user_dict_startup_processing( + user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path + ) + + self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation) diff --git a/voicevox_engine/user_dict.py b/voicevox_engine/user_dict.py index 0f0e911d4..fadfa3e35 100644 --- a/voicevox_engine/user_dict.py +++ b/voicevox_engine/user_dict.py @@ -42,6 +42,7 @@ def write_to_json(user_dict: Dict[str, UserDictWord], user_dict_path: Path): def user_dict_startup_processing( default_dict_path: Path = default_dict_path, + user_dict_path: Path = user_dict_path, compiled_dict_path: Path = compiled_dict_path, ): pyopenjtalk.create_user_dict( @@ -51,12 +52,15 @@ def user_dict_startup_processing( pyopenjtalk.set_user_dict(str(compiled_dict_path.resolve(strict=True))) if user_dict_path.is_file(): update_dict( - default_dict_path=default_dict_path, compiled_dict_path=compiled_dict_path + default_dict_path=default_dict_path, + user_dict_path=user_dict_path, + compiled_dict_path=compiled_dict_path, ) def update_dict( default_dict_path: Path = default_dict_path, + user_dict_path: Path = user_dict_path, compiled_dict_path: Path = compiled_dict_path, ): with NamedTemporaryFile(encoding="utf-8", mode="w", delete=False) as f: @@ -67,7 +71,7 @@ def update_dict( if default_dict == default_dict.rstrip(): default_dict += "\n" f.write(default_dict) - user_dict = read_dict() + user_dict = read_dict(user_dict_path=user_dict_path) for word_uuid in user_dict: word = user_dict[word_uuid] f.write( @@ -185,7 +189,7 @@ def apply_word( word_uuid = str(uuid4()) user_dict[word_uuid] = word write_to_json(user_dict, user_dict_path) - update_dict(compiled_dict_path=compiled_dict_path) + update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path) return word_uuid @@ -211,7 +215,7 @@ def rewrite_word( raise HTTPException(status_code=422, detail="UUIDに該当するワードが見つかりませんでした") user_dict[word_uuid] = word write_to_json(user_dict, user_dict_path) - update_dict(compiled_dict_path=compiled_dict_path) + update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path) def delete_word( @@ -224,7 +228,7 @@ def delete_word( raise HTTPException(status_code=422, detail="IDに該当するワードが見つかりませんでした") del user_dict[word_uuid] write_to_json(user_dict, user_dict_path) - update_dict(compiled_dict_path=compiled_dict_path) + update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path) def import_user_dict( @@ -263,7 +267,9 @@ def import_user_dict( new_dict = {**dict_data, **old_dict} write_to_json(user_dict=new_dict, user_dict_path=user_dict_path) update_dict( - default_dict_path=default_dict_path, compiled_dict_path=compiled_dict_path + default_dict_path=default_dict_path, + user_dict_path=user_dict_path, + compiled_dict_path=compiled_dict_path, )