Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: コア・エンジンでバージョンを指定しない場合、暗黙的に最新版を取得する処理を削除 #1317

Merged
merged 18 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def main() -> None:
)
tts_engines = make_tts_engines_from_cores(core_manager)
assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。"
latest_core_version = tts_engines.latest_version()
latest_core_version = core_manager.latest_version()

# Cancellable Engine
enable_cancellable_synthesis: bool = args.enable_cancellable_synthesis
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def app_params(tmp_path: Path) -> dict[str, Any]:
core_manager = initialize_cores(use_gpu=False, enable_mock=True)
tts_engines = make_tts_engines_from_cores(core_manager)
latest_core_version = tts_engines.latest_version()
latest_core_version = core_manager.latest_version()
setting_loader = SettingHandler(Path("./not_exist.yaml"))

# 隔離されたプリセットの生成
Expand Down
37 changes: 25 additions & 12 deletions test/test_core_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,38 @@ def test_cores_latest_version() -> None:
assert true_latest_version == latest_version


def test_cores_get_core_specified() -> None:
"""CoreManager.get_core() で登録済みコアをバージョン指定して取得できる。"""
def test_cores_convert_version_format_non_latest() -> None:
"""CoreManager.convert_version_format() で明示的バージョンが維持される。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
core2 = CoreAdapter(MockCoreWrapper())
core_manager.register_core(core1, "0.0.1")
core_manager.register_core(core2, "0.0.2")
api_format_version = "0.0.2"
# Expects
true_acquired_core = core2
true_version = "0.0.2"
# Outputs
acquired_core = core_manager.get_core("0.0.2")
version = core_manager.convert_version_format(api_format_version)

# Test
assert true_acquired_core == acquired_core
assert true_version == version


def test_cores_convert_version_format_latest() -> None:
"""CoreManager.convert_version_format() で latest 表現が変換される。"""
# Inputs
core_manager = CoreManager()
core_manager.register_core(CoreAdapter(MockCoreWrapper()), "0.0.1")
core_manager.register_core(CoreAdapter(MockCoreWrapper()), "0.0.2")
api_format_version = None
# Expects
true_version = "0.0.2"
# Outputs
version = core_manager.convert_version_format(api_format_version)

# Test
assert true_version == version


def test_cores_get_core_latest() -> None:
"""CoreManager.get_core() で最新版コアをバージョン未指定で取得できる。"""
def test_cores_get_core_existing() -> None:
"""CoreManager.get_core() で登録済みコアを取得できる。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
Expand All @@ -78,7 +91,7 @@ def test_cores_get_core_latest() -> None:
# Expects
true_acquired_core = core2
# Outputs
acquired_core = core_manager.get_core()
acquired_core = core_manager.get_core("0.0.2")

# Test
assert true_acquired_core == acquired_core
Expand Down
36 changes: 2 additions & 34 deletions test/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,8 @@ def test_tts_engines_versions() -> None:
assert true_versions == versions


def test_tts_engines_latest_version() -> None:
"""TTSEngineManager.latest_version() で最新バージョンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engines.register_engine(MockTTSEngine(), "0.0.1")
tts_engines.register_engine(MockTTSEngine(), "0.0.2")
# Expects
true_latest_version = "0.0.2"
# Outputs
latest_version = tts_engines.latest_version()

# Test
assert true_latest_version == latest_version


def test_tts_engines_get_engine_specified() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンをバージョン指定して取得できる。"""
def test_tts_engines_get_engine_existing() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
Expand All @@ -63,23 +48,6 @@ def test_tts_engines_get_engine_specified() -> None:
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_latest() -> None:
"""TTSEngineManager.get_engine() で最新版 TTS エンジンをバージョン未指定で取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")
# Expects
true_acquired_tts_engine = tts_engine2
# Outputs
acquired_tts_engine = tts_engines.get_engine()

# Test
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
# Inputs
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ async def core_versions() -> Response:
)
def supported_devices(core_version: str | None = None) -> Response:
"""対応デバイスの一覧を取得します。"""
supported_devices = core_manager.get_core(core_version).supported_devices
version = core_manager.convert_version_format(core_version)
supported_devices = core_manager.get_core(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return Response(
Expand Down
8 changes: 5 additions & 3 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def morphable_targets(
プロパティが存在しない場合は、モーフィングが許可されているとみなします。
返り値のスタイルIDはstring型なので注意。
"""
core = core_manager.get_core(core_version)
version = core_manager.convert_version_format(core_version)
core = core_manager.get_core(version)

try:
speakers = metas_store.load_combined_metas(core=core)
Expand Down Expand Up @@ -94,8 +95,9 @@ def _synthesis_morphing(
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)

try:
speakers = metas_store.load_combined_metas(core=core)
Expand Down
16 changes: 11 additions & 5 deletions voicevox_engine/app/routers/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def generate_speaker_router(
def speakers(
core_version: str | None = None,
) -> list[Speaker]:
speakers = metas_store.load_combined_metas(core_manager.get_core(core_version))
version = core_manager.convert_version_format(core_version)
speakers = metas_store.load_combined_metas(core_manager.get_core(version))
return filter_speakers_and_styles(speakers, "speaker")

@router.get("/speaker_info", tags=["その他"])
Expand Down Expand Up @@ -77,9 +78,11 @@ def _speaker_info(
# {speaker_uuid_1}/
# ...

version = core_manager.convert_version_format(core_version)

# 該当話者の検索
speakers = parse_obj_as(
list[Speaker], json.loads(core_manager.get_core(core_version).speakers)
list[Speaker], json.loads(core_manager.get_core(version).speakers)
)
speakers = filter_speakers_and_styles(speakers, speaker_or_singer)
for i in range(len(speakers)):
Expand Down Expand Up @@ -147,7 +150,8 @@ def _speaker_info(
def singers(
core_version: str | None = None,
) -> list[Speaker]:
singers = metas_store.load_combined_metas(core_manager.get_core(core_version))
version = core_manager.convert_version_format(core_version)
singers = metas_store.load_combined_metas(core_manager.get_core(version))
return filter_speakers_and_styles(singers, "singer")

@router.get("/singer_info", tags=["その他"])
Expand Down Expand Up @@ -180,7 +184,8 @@ def initialize_speaker(
指定されたスタイルを初期化します。
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
core = core_manager.get_core(core_version)
version = core_manager.convert_version_format(core_version)
core = core_manager.get_core(version)
core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit)

@router.get("/is_initialized_speaker", tags=["その他"])
Expand All @@ -191,7 +196,8 @@ def is_initialized_speaker(
"""
指定されたスタイルが初期化されているかどうかを返します。
"""
core = core_manager.get_core(core_version)
version = core_manager.convert_version_format(core_version)
core = core_manager.get_core(version)
return core.is_initialized_style_id_synthesis(style_id)

return router
42 changes: 27 additions & 15 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def audio_query(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
Expand Down Expand Up @@ -82,8 +83,9 @@ def audio_query_from_preset(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
presets = preset_manager.load_presets()
except PresetInputError as err:
Expand Down Expand Up @@ -139,7 +141,8 @@ def accent_phrases(
* アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。
* アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。
"""
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
if is_kana:
try:
return engine.create_accent_phrases_from_kana(text, style_id)
Expand All @@ -160,7 +163,8 @@ def mora_data(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
return engine.update_length_and_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -173,7 +177,8 @@ def mora_length(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
return engine.update_length(accent_phrases, style_id)

@router.post(
Expand All @@ -186,7 +191,8 @@ def mora_pitch(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
return engine.update_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -213,7 +219,8 @@ def synthesis(
] = True,
core_version: str | None = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)
Expand Down Expand Up @@ -253,8 +260,9 @@ def cancellable_synthesis(
status_code=404,
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
version = core_manager.convert_version_format(core_version)
f_name = cancellable_engine._synthesis_impl(
query, style_id, request, core_version=core_version
query, style_id, request, version=version
)
if f_name == "":
raise HTTPException(status_code=422, detail="不明なバージョンです")
Expand Down Expand Up @@ -285,7 +293,8 @@ def multi_synthesis(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
sampling_rate = queries[0].outputSamplingRate

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -327,8 +336,9 @@ def sing_frame_audio_query(
"""
歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
score, style_id
)
Expand All @@ -353,7 +363,8 @@ def sing_frame_volume(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[float]:
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
return engine.create_sing_volume_from_phoneme_and_f0(
score, frame_audio_query.phonemes, frame_audio_query.f0, style_id
)
Expand All @@ -378,7 +389,8 @@ def frame_synthesis(
"""
歌唱音声合成を行います。
"""
engine = tts_engines.get_engine(core_version)
version = core_manager.convert_version_format(core_version)
engine = tts_engines.get_engine(version)
wave = engine.frame_synthsize_wave(query, style_id)

with NamedTemporaryFile(delete=False) as f:
Expand Down
14 changes: 6 additions & 8 deletions voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _synthesis_impl(
query: AudioQuery,
style_id: StyleId,
request: Request,
core_version: str | None,
version: str,
) -> str:
"""
音声合成を行う関数
Expand All @@ -157,7 +157,7 @@ def _synthesis_impl(
request: fastapi.Request
接続確立時に受け取ったものをそのまま渡せばよい
https://fastapi.tiangolo.com/advanced/using-request-directly/
core_version: str
version: str

Returns
-------
Expand All @@ -167,7 +167,7 @@ def _synthesis_impl(
proc, sub_proc_con1 = self.procs_and_cons.get()
self.watch_con_list.append((request, proc))
try:
sub_proc_con1.send((query, style_id, core_version))
sub_proc_con1.send((query, style_id, version))
f_name = sub_proc_con1.recv()
if isinstance(f_name, str):
audio_file_name = f_name
Expand Down Expand Up @@ -240,11 +240,9 @@ def start_synthesis_subprocess(
assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。"
while True:
try:
query, style_id, core_version = sub_proc_con.recv()
if core_version is None:
_engine = tts_engines.get_engine()
elif tts_engines.has_engine(core_version):
_engine = tts_engines.get_engine(core_version)
query, style_id, version = sub_proc_con.recv()
if tts_engines.has_engine(version):
_engine = tts_engines.get_engine(version)
else:
# バージョンが見つからないエラー
sub_proc_con.send("")
Expand Down
Loading