Skip to content

Commit

Permalink
整理: 内部型 WordProperty を追加 (VOICEVOX#1333)
Browse files Browse the repository at this point in the history
* refactor: 内部型 `SimpleUserDictWord` を追加

* refactor: `SimpleUserDictWord` の docstring を充実

* refactor: `SimpleUserDictWord` → `WordProperty` にリネーム

* fix: lint

* fix: import 元を間接から直接へ修正

* Update voicevox_engine/user_dict/user_dict_word.py

---------

Co-authored-by: Hiroshiba <[email protected]>
  • Loading branch information
tarepan and Hiroshiba authored Jun 2, 2024
1 parent 5840f9b commit 4e13143
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 137 deletions.
49 changes: 23 additions & 26 deletions test/user_dict/test_user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from voicevox_engine.user_dict.user_dict_word import (
MAX_PRIORITY,
UserDictInputError,
WordProperty,
create_word,
part_of_speech_data,
)
Expand Down Expand Up @@ -75,11 +76,7 @@ def test_read_not_exist_json(tmp_path: Path) -> None:
def test_create_word() -> None:
# 将来的に品詞などが追加された時にテストを増やす
assert create_word(
surface="test",
pronunciation="テスト",
accent_type=1,
word_type=None,
priority=None,
WordProperty(surface="test", pronunciation="テスト", accent_type=1)
) == UserDictWord(
surface="test",
priority=5,
Expand All @@ -104,7 +101,9 @@ def test_apply_word_without_json(tmp_path: Path) -> None:
user_dict_path=tmp_path / "test_apply_word_without_json.json",
compiled_dict_path=tmp_path / "test_apply_word_without_json.dic",
)
user_dict.apply_word(surface="test", pronunciation="テスト", accent_type=1)
user_dict.apply_word(
WordProperty(surface="test", pronunciation="テスト", accent_type=1)
)
res = user_dict.read_dict()
assert len(res) == 1
new_word = get_new_word(res)
Expand All @@ -125,9 +124,7 @@ def test_apply_word_with_json(tmp_path: Path) -> None:
compiled_dict_path=tmp_path / "test_apply_word_with_json.dic",
)
user_dict.apply_word(
surface="test2",
pronunciation="テストツー",
accent_type=3,
WordProperty(surface="test2", pronunciation="テストツー", accent_type=3)
)
res = user_dict.read_dict()
assert len(res) == 2
Expand All @@ -150,10 +147,8 @@ def test_rewrite_word_invalid_id(tmp_path: Path) -> None:
)
with pytest.raises(UserDictInputError):
user_dict.rewrite_word(
word_uuid="c2be4dc5-d07d-4767-8be1-04a1bb3f05a9",
surface="test2",
pronunciation="テストツー",
accent_type=2,
"c2be4dc5-d07d-4767-8be1-04a1bb3f05a9",
WordProperty(surface="test2", pronunciation="テストツー", accent_type=2),
)


Expand All @@ -167,10 +162,8 @@ def test_rewrite_word_valid_id(tmp_path: Path) -> None:
compiled_dict_path=tmp_path / "test_rewrite_word_valid_id.dic",
)
user_dict.rewrite_word(
word_uuid="aab7dda2-0d97-43c8-8cb7-3f440dab9b4e",
surface="test2",
pronunciation="テストツー",
accent_type=2,
"aab7dda2-0d97-43c8-8cb7-3f440dab9b4e",
WordProperty(surface="test2", pronunciation="テストツー", accent_type=2),
)
new_word = user_dict.read_dict()["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"]
assert (new_word.surface, new_word.pronunciation, new_word.accent_type) == (
Expand Down Expand Up @@ -211,11 +204,13 @@ def test_priority() -> None:
for i in range(MAX_PRIORITY + 1):
assert (
create_word(
surface="test",
pronunciation="テスト",
accent_type=1,
word_type=pos,
priority=i,
WordProperty(
surface="test",
pronunciation="テスト",
accent_type=1,
word_type=pos,
priority=i,
)
).priority
== i
)
Expand Down Expand Up @@ -316,10 +311,12 @@ def test_update_dict(tmp_path: Path) -> None:
assert g2p(text=test_text, kana=True) != success_pronunciation

user_dict.apply_word(
surface=test_text,
pronunciation=success_pronunciation,
accent_type=1,
priority=10,
WordProperty(
surface=test_text,
pronunciation=success_pronunciation,
accent_type=1,
priority=10,
)
)
assert g2p(text=test_text, kana=True) == success_pronunciation

Expand Down
27 changes: 16 additions & 11 deletions voicevox_engine/app/routers/user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
MAX_PRIORITY,
MIN_PRIORITY,
UserDictInputError,
WordProperty,
)

from ..dependencies import check_disabled_mutable_api
Expand Down Expand Up @@ -65,11 +66,13 @@ def add_user_dict_word(
"""
try:
word_uuid = user_dict.apply_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_type=word_type,
priority=priority,
WordProperty(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_type=word_type,
priority=priority,
)
)
return word_uuid
except ValidationError as e:
Expand Down Expand Up @@ -115,12 +118,14 @@ def rewrite_user_dict_word(
"""
try:
user_dict.rewrite_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_uuid=word_uuid,
word_type=word_type,
priority=priority,
word_uuid,
WordProperty(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_type=word_type,
priority=priority,
),
)
except ValidationError as e:
raise HTTPException(
Expand Down
77 changes: 7 additions & 70 deletions voicevox_engine/user_dict/user_dict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .model import UserDictWord, WordTypes
from .user_dict_word import (
UserDictInputError,
WordProperty,
cost2priority,
create_word,
part_of_speech_data,
Expand Down Expand Up @@ -296,44 +297,12 @@ def import_user_dict(
compiled_dict_path=self._compiled_dict_path,
)

def apply_word(
self,
surface: str,
pronunciation: str,
accent_type: int,
word_type: WordTypes | None = None,
priority: int | None = None,
) -> str:
"""
新規単語を追加する。
Parameters
----------
surface : str
単語情報
pronunciation : str
単語情報
accent_type : int
単語情報
word_type : WordTypes | None
品詞
priority : int | None
優先度
Returns
-------
word_uuid : UserDictWord
追加された単語に発行されたUUID
"""
def apply_word(self, word_property: WordProperty) -> str:
"""新規単語を追加し、その単語に割り当てられた UUID を返す。"""
# 新規単語の追加による辞書データの更新
word = create_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_type=word_type,
priority=priority,
)
user_dict = _read_dict(user_dict_path=self._user_dict_path)
word_uuid = str(uuid4())
user_dict[word_uuid] = word
user_dict[word_uuid] = create_word(word_property)

# 更新された辞書データの保存と適用
_write_to_json(user_dict, self._user_dict_path)
Expand All @@ -345,45 +314,13 @@ def apply_word(

return word_uuid

def rewrite_word(
self,
word_uuid: str,
surface: str,
pronunciation: str,
accent_type: int,
word_type: WordTypes | None = None,
priority: int | None = None,
) -> None:
"""
既存単語を上書き更新する。
Parameters
----------
word_uuid : str
単語UUID
surface : str
単語情報
pronunciation : str
単語情報
accent_type : int
単語情報
word_type : WordTypes | None
品詞
priority : int | None
優先度
"""
word = create_word(
surface=surface,
pronunciation=pronunciation,
accent_type=accent_type,
word_type=word_type,
priority=priority,
)

def rewrite_word(self, word_uuid: str, word_property: WordProperty) -> None:
"""単語 UUID で指定された単語を上書き更新する。"""
# 既存単語の上書きによる辞書データの更新
user_dict = _read_dict(user_dict_path=self._user_dict_path)
if word_uuid not in user_dict:
raise UserDictInputError("UUIDに該当するワードが見つかりませんでした")
user_dict[word_uuid] = word
user_dict[word_uuid] = create_word(word_property)

# 更新された辞書データの保存と適用
_write_to_json(user_dict, self._user_dict_path)
Expand Down
53 changes: 23 additions & 30 deletions voicevox_engine/user_dict/user_dict_word.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""ユーザー辞書を構成する言葉(単語)関連の処理"""

from dataclasses import dataclass

import numpy as np
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -164,43 +166,34 @@ class PartOfSpeechDetail(BaseModel):
}


def create_word(
surface: str,
pronunciation: str,
accent_type: int,
word_type: WordTypes | None,
priority: int | None,
) -> UserDictWord:
"""
単語オブジェクトの生成
Parameters
----------
surface : str
単語情報
pronunciation : str
単語情報
accent_type : int
単語情報
word_type : WordTypes | None
品詞
priority : int | None
優先度
Returns
-------
: UserDictWord
単語オブジェクト
"""
@dataclass
class WordProperty:
"""単語属性のあつまり"""

surface: str # 単語情報
pronunciation: str # 単語情報
accent_type: int # 単語情報
word_type: WordTypes | None = None # 品詞
priority: int | None = None # 優先度


def create_word(word_property: WordProperty) -> UserDictWord:
"""単語オブジェクトを生成する。"""
word_type: WordTypes | None = word_property.word_type
if word_type is None:
word_type = WordTypes.PROPER_NOUN
if word_type not in part_of_speech_data.keys():
raise UserDictInputError("不明な品詞です")

priority: int | None = word_property.priority
if priority is None:
priority = 5
if not MIN_PRIORITY <= priority <= MAX_PRIORITY:
raise UserDictInputError("優先度の値が無効です")

pos_detail = part_of_speech_data[word_type]
return UserDictWord(
surface=surface,
surface=word_property.surface,
context_id=pos_detail.context_id,
priority=priority,
part_of_speech=pos_detail.part_of_speech,
Expand All @@ -210,9 +203,9 @@ def create_word(
inflectional_type="*",
inflectional_form="*",
stem="*",
yomi=pronunciation,
pronunciation=pronunciation,
accent_type=accent_type,
yomi=word_property.pronunciation,
pronunciation=word_property.pronunciation,
accent_type=word_property.accent_type,
mora_count=None,
accent_associative_rule="*",
)
Expand Down

0 comments on commit 4e13143

Please sign in to comment.