From 1ee11273834fa0ff63e1a85c3043cc8b028d1d1a Mon Sep 17 00:00:00 2001 From: tarepan Date: Sun, 23 Jun 2024 08:33:04 +0000 Subject: [PATCH 1/5] =?UTF-8?q?add:=20`.get=5Fengine()`=20=E3=81=AB=20None?= =?UTF-8?q?=20=E5=85=A5=E5=8A=9B=E3=81=AB=E3=82=88=E3=82=8B=20latest=20?= =?UTF-8?q?=E5=8F=96=E5=BE=97=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/unit/tts_pipeline/test_tts_engines.py | 19 +++++++++++++ voicevox_engine/app/routers/morphing.py | 2 +- voicevox_engine/app/routers/tts_pipeline.py | 30 ++++++++------------- voicevox_engine/tts_pipeline/tts_engine.py | 8 +++++- 4 files changed, 38 insertions(+), 21 deletions(-) diff --git a/test/unit/tts_pipeline/test_tts_engines.py b/test/unit/tts_pipeline/test_tts_engines.py index 1fa32c904..be90ce6a2 100644 --- a/test/unit/tts_pipeline/test_tts_engines.py +++ b/test/unit/tts_pipeline/test_tts_engines.py @@ -48,6 +48,25 @@ def test_tts_engines_get_engine_existing() -> None: assert true_acquired_tts_engine == acquired_tts_engine +def test_tts_engines_get_engine_latest() -> None: + """TTSEngineManager.get_engine(None) で最新版の TTS エンジンを取得できる。""" + # Inputs + tts_engines = TTSEngineManager() + tts_engine1 = MockTTSEngine() + tts_engine2 = MockTTSEngine() + tts_engine3 = MockTTSEngine() + tts_engines.register_engine(tts_engine1, "0.0.1") + tts_engines.register_engine(tts_engine2, "0.0.2") + tts_engines.register_engine(tts_engine3, "0.1.0") + # Expects + true_acquired_tts_engine = tts_engine3 + # Outputs + acquired_tts_engine = tts_engines.get_engine(None) + + # Test + assert true_acquired_tts_engine == acquired_tts_engine + + def test_tts_engines_get_engine_missing() -> None: """TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。""" # Inputs diff --git a/voicevox_engine/app/routers/morphing.py b/voicevox_engine/app/routers/morphing.py index b4fcd734a..f6e63c9db 100644 --- a/voicevox_engine/app/routers/morphing.py +++ b/voicevox_engine/app/routers/morphing.py @@ -94,7 +94,7 @@ def _synthesis_morphing( モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。 """ version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) core = core_manager.get_core(version) # モーフィングが許可されないキャラクターペアを拒否する diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index 4a2159a09..48de16f3d 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -86,7 +86,7 @@ def audio_query( 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) core = core_manager.get_core(version) accent_phrases = engine.create_accent_phrases(text, style_id) return AudioQuery( @@ -118,7 +118,7 @@ def audio_query_from_preset( 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) core = core_manager.get_core(version) try: presets = preset_manager.load_presets() @@ -177,8 +177,7 @@ def accent_phrases( * アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。 * アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。 """ - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) if is_kana: try: return engine.create_accent_phrases_from_kana(text, style_id) @@ -199,8 +198,7 @@ def mora_data( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) return engine.update_length_and_pitch(accent_phrases, style_id) @router.post( @@ -213,8 +211,7 @@ def mora_length( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) return engine.update_length(accent_phrases, style_id) @router.post( @@ -227,8 +224,7 @@ def mora_pitch( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) return engine.update_pitch(accent_phrases, style_id) @router.post( @@ -255,8 +251,7 @@ def synthesis( ] = True, core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) wave = engine.synthesize_wave( query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak ) @@ -333,8 +328,7 @@ def multi_synthesis( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) sampling_rate = queries[0].outputSamplingRate with NamedTemporaryFile(delete=False) as f: @@ -377,7 +371,7 @@ def sing_frame_audio_query( 歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) core = core_manager.get_core(version) try: phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume( @@ -406,8 +400,7 @@ def sing_frame_volume( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[float]: - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) try: return engine.create_sing_volume_from_phoneme_and_f0( score, frame_audio_query.phonemes, frame_audio_query.f0, style_id @@ -435,8 +428,7 @@ def frame_synthesis( """ 歌唱音声合成を行います。 """ - version = core_version or core_manager.latest_version() - engine = tts_engines.get_engine(version) + engine = tts_engines.get_engine(core_version) try: wave = engine.frame_synthsize_wave(query, style_id) except TalkSingInvalidInputError as e: diff --git a/voicevox_engine/tts_pipeline/tts_engine.py b/voicevox_engine/tts_pipeline/tts_engine.py index 9d3c57248..f56042b2a 100644 --- a/voicevox_engine/tts_pipeline/tts_engine.py +++ b/voicevox_engine/tts_pipeline/tts_engine.py @@ -8,6 +8,8 @@ from numpy.typing import NDArray from soxr import resample +from voicevox_engine.utility.core_version_utility import get_latest_version + from ..core.core_adapter import CoreAdapter from ..core.core_initializer import CoreManager from ..core.core_wrapper import CoreWrapper @@ -701,12 +703,16 @@ def versions(self) -> list[str]: """登録されたエンジンのバージョン一覧を取得する。""" return list(self._engines.keys()) + def _latest_version(self) -> str: + return get_latest_version(self.versions()) + def register_engine(self, engine: TTSEngine, version: str) -> None: """エンジンを登録する。""" self._engines[version] = engine - def get_engine(self, version: str) -> TTSEngine: + def get_engine(self, core_version: str | None) -> TTSEngine: """指定バージョンのエンジンを取得する。""" + version = core_version or self._latest_version() if version in self._engines: return self._engines[version] From dfc4bd0543acf04e0114a0173dec4080e811d92e Mon Sep 17 00:00:00 2001 From: tarepan Date: Tue, 25 Jun 2024 19:20:11 +0000 Subject: [PATCH 2/5] =?UTF-8?q?refactor:=20Cancellable=20=E3=81=AE=20versi?= =?UTF-8?q?on=20=E3=82=92=20core=5Fversion=20=E3=81=B8=E5=A4=89=E6=9B=B4?= =?UTF-8?q?=E3=81=97=E3=81=A6=E6=95=B4=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- voicevox_engine/app/application.py | 4 +--- voicevox_engine/app/routers/tts_pipeline.py | 5 +---- voicevox_engine/cancellable_engine.py | 14 +++++++------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/voicevox_engine/app/application.py b/voicevox_engine/app/application.py index 862b615e4..83b6c819e 100644 --- a/voicevox_engine/app/application.py +++ b/voicevox_engine/app/application.py @@ -69,9 +69,7 @@ def generate_app( metas_store = MetasStore(speaker_info_dir, resource_manager) app.include_router( - generate_tts_pipeline_router( - tts_engines, core_manager, preset_manager, cancellable_engine - ) + generate_tts_pipeline_router(tts_engines, preset_manager, cancellable_engine) ) app.include_router(generate_morphing_router(tts_engines, core_manager, metas_store)) app.include_router( diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index 907eaefe6..23f3cbec8 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -15,7 +15,6 @@ CancellableEngine, CancellableEngineInternalError, ) -from voicevox_engine.core.core_initializer import CoreManager from voicevox_engine.metas.Metas import StyleId from voicevox_engine.model import AudioQuery from voicevox_engine.preset.preset_manager import ( @@ -65,7 +64,6 @@ def __init__(self, err: ParseKanaError): def generate_tts_pipeline_router( tts_engines: TTSEngineManager, - core_manager: CoreManager, preset_manager: PresetManager, cancellable_engine: CancellableEngine | None, ) -> APIRouter: @@ -287,10 +285,9 @@ def cancellable_synthesis( status_code=404, detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。", ) - version = core_version or core_manager.latest_version() try: f_name = cancellable_engine._synthesis_impl( - query, style_id, request, version=version + query, style_id, request, core_version=core_version ) except CancellableEngineInternalError as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/voicevox_engine/cancellable_engine.py b/voicevox_engine/cancellable_engine.py index 2899fb5e4..7e0dbcf7d 100644 --- a/voicevox_engine/cancellable_engine.py +++ b/voicevox_engine/cancellable_engine.py @@ -149,7 +149,7 @@ def _synthesis_impl( query: AudioQuery, style_id: StyleId, request: Request, - version: str, + core_version: str | None, ) -> str: """ 音声合成を行う関数 @@ -163,7 +163,7 @@ def _synthesis_impl( request: fastapi.Request 接続確立時に受け取ったものをそのまま渡せばよい https://fastapi.tiangolo.com/advanced/using-request-directly/ - version: str + core_version Returns ------- @@ -173,7 +173,7 @@ def _synthesis_impl( proc, sub_proc_con1 = self.procs_and_cons.get() self.watch_con_list.append((request, proc)) try: - sub_proc_con1.send((query, style_id, version)) + sub_proc_con1.send((query, style_id, core_version)) f_name = sub_proc_con1.recv() if isinstance(f_name, str): audio_file_name = f_name @@ -244,10 +244,10 @@ def start_synthesis_subprocess( assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。" while True: try: - query, style_id, version = sub_proc_con.recv() - if tts_engines.has_engine(version): - _engine = tts_engines.get_engine(version) - else: + query, style_id, core_version = sub_proc_con.recv() + try: + _engine = tts_engines.get_engine(core_version) + except Exception: # バージョンが見つからないエラー sub_proc_con.send("") continue From 76b220bd1c6f43287dd8b578c83d38e614c209b2 Mon Sep 17 00:00:00 2001 From: tarepan Date: Thu, 27 Jun 2024 05:43:51 +0000 Subject: [PATCH 3/5] =?UTF-8?q?refactor:=20morphing=20=E3=81=AE=E3=82=B3?= =?UTF-8?q?=E3=82=A2=E7=9B=B4=E6=8E=A5=E4=BE=9D=E5=AD=98=E3=82=92=E5=89=8A?= =?UTF-8?q?=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- voicevox_engine/app/application.py | 2 +- voicevox_engine/app/routers/morphing.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/voicevox_engine/app/application.py b/voicevox_engine/app/application.py index 1d61ab5bc..8315e2dc6 100644 --- a/voicevox_engine/app/application.py +++ b/voicevox_engine/app/application.py @@ -82,7 +82,7 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]: app.include_router( generate_tts_pipeline_router(tts_engines, preset_manager, cancellable_engine) ) - app.include_router(generate_morphing_router(tts_engines, core_manager, metas_store)) + app.include_router(generate_morphing_router(tts_engines, metas_store)) app.include_router( generate_preset_router(preset_manager, verify_mutability_allowed) ) diff --git a/voicevox_engine/app/routers/morphing.py b/voicevox_engine/app/routers/morphing.py index 42d4ebf91..d0e3b6ba8 100644 --- a/voicevox_engine/app/routers/morphing.py +++ b/voicevox_engine/app/routers/morphing.py @@ -10,7 +10,6 @@ from starlette.background import BackgroundTask from starlette.responses import FileResponse -from voicevox_engine.core.core_initializer import CoreManager from voicevox_engine.metas.Metas import StyleId from voicevox_engine.metas.MetasStore import MetasStore from voicevox_engine.model import AudioQuery @@ -34,9 +33,7 @@ def generate_morphing_router( - tts_engines: TTSEngineManager, - core_manager: CoreManager, - metas_store: MetasStore, + tts_engines: TTSEngineManager, metas_store: MetasStore ) -> APIRouter: """モーフィング API Router を生成する""" router = APIRouter(tags=["音声合成"]) From 3b0428dd7c05abd94d3e2e506e98504e22045a4a Mon Sep 17 00:00:00 2001 From: tarepan Date: Thu, 27 Jun 2024 10:07:09 +0000 Subject: [PATCH 4/5] =?UTF-8?q?refactor:=20None=20=E2=86=92=20LATEST=5FVER?= =?UTF-8?q?SION=20=E3=81=AE=E5=A4=89=E6=8F=9B=E3=82=92=E5=B0=8E=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/unit/tts_pipeline/test_tts_engines.py | 6 +-- voicevox_engine/app/routers/morphing.py | 5 ++- voicevox_engine/app/routers/tts_pipeline.py | 43 ++++++++++++++------- voicevox_engine/cancellable_engine.py | 12 +++--- voicevox_engine/tts_pipeline/tts_engine.py | 12 ++++-- 5 files changed, 50 insertions(+), 28 deletions(-) diff --git a/test/unit/tts_pipeline/test_tts_engines.py b/test/unit/tts_pipeline/test_tts_engines.py index be90ce6a2..c49bf5462 100644 --- a/test/unit/tts_pipeline/test_tts_engines.py +++ b/test/unit/tts_pipeline/test_tts_engines.py @@ -4,7 +4,7 @@ from fastapi import HTTPException from voicevox_engine.dev.tts_engine.mock import MockTTSEngine -from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager +from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager def test_tts_engines_register_engine() -> None: @@ -49,7 +49,7 @@ def test_tts_engines_get_engine_existing() -> None: def test_tts_engines_get_engine_latest() -> None: - """TTSEngineManager.get_engine(None) で最新版の TTS エンジンを取得できる。""" + """TTSEngineManager.get_engine(LATEST_VERSION) で最新版の TTS エンジンを取得できる。""" # Inputs tts_engines = TTSEngineManager() tts_engine1 = MockTTSEngine() @@ -61,7 +61,7 @@ def test_tts_engines_get_engine_latest() -> None: # Expects true_acquired_tts_engine = tts_engine3 # Outputs - acquired_tts_engine = tts_engines.get_engine(None) + acquired_tts_engine = tts_engines.get_engine(LATEST_VERSION) # Test assert true_acquired_tts_engine == acquired_tts_engine diff --git a/voicevox_engine/app/routers/morphing.py b/voicevox_engine/app/routers/morphing.py index d0e3b6ba8..dd8199757 100644 --- a/voicevox_engine/app/routers/morphing.py +++ b/voicevox_engine/app/routers/morphing.py @@ -23,7 +23,7 @@ synthesis_morphing_parameter as _synthesis_morphing_parameter, ) from voicevox_engine.morphing.morphing import synthesize_morphed_wave -from voicevox_engine.tts_pipeline.tts_engine import TTSEngineManager +from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager from voicevox_engine.utility.file_utility import try_delete_file # キャッシュを有効化 @@ -86,7 +86,8 @@ def _synthesis_morphing( 指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。 モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) # モーフィングが許可されないキャラクターペアを拒否する characters = metas_store.characters(core_version) diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index 8791141ef..6555c844c 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -38,6 +38,7 @@ Score, ) from voicevox_engine.tts_pipeline.tts_engine import ( + LATEST_VERSION, TalkSingInvalidInputError, TTSEngineManager, ) @@ -83,7 +84,8 @@ def audio_query( """ 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) accent_phrases = engine.create_accent_phrases(text, style_id) return AudioQuery( accent_phrases=accent_phrases, @@ -113,7 +115,8 @@ def audio_query_from_preset( """ 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) try: presets = preset_manager.load_presets() except PresetInputError as err: @@ -171,7 +174,8 @@ def accent_phrases( * アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。 * アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) if is_kana: try: return engine.create_accent_phrases_from_kana(text, style_id) @@ -192,7 +196,8 @@ def mora_data( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) return engine.update_length_and_pitch(accent_phrases, style_id) @router.post( @@ -205,7 +210,8 @@ def mora_length( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) return engine.update_length(accent_phrases, style_id) @router.post( @@ -218,7 +224,8 @@ def mora_pitch( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) return engine.update_pitch(accent_phrases, style_id) @router.post( @@ -245,7 +252,8 @@ def synthesis( ] = True, core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) wave = engine.synthesize_wave( query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak ) @@ -286,8 +294,9 @@ def cancellable_synthesis( detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。", ) try: + version = core_version or LATEST_VERSION f_name = cancellable_engine._synthesis_impl( - query, style_id, request, core_version=core_version + query, style_id, request, version=version ) except CancellableEngineInternalError as e: raise HTTPException(status_code=500, detail=str(e)) @@ -321,7 +330,8 @@ def multi_synthesis( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) sampling_rate = queries[0].outputSamplingRate with NamedTemporaryFile(delete=False) as f: @@ -363,7 +373,8 @@ def sing_frame_audio_query( """ 歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) try: phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume( score, style_id @@ -391,7 +402,8 @@ def sing_frame_volume( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[float]: - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) try: return engine.create_sing_volume_from_phoneme_and_f0( score, frame_audio_query.phonemes, frame_audio_query.f0, style_id @@ -419,7 +431,8 @@ def frame_synthesis( """ 歌唱音声合成を行います。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) try: wave = engine.frame_synthsize_wave(query, style_id) except TalkSingInvalidInputError as e: @@ -514,7 +527,8 @@ def initialize_speaker( 指定されたスタイルを初期化します。 実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) engine.initialize_synthesis(style_id, skip_reinit=skip_reinit) @router.get("/is_initialized_speaker", tags=["その他"]) @@ -525,7 +539,8 @@ def is_initialized_speaker( """ 指定されたスタイルが初期化されているかどうかを返します。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or LATEST_VERSION + engine = tts_engines.get_engine(version) return engine.is_synthesis_initialized(style_id) return router diff --git a/voicevox_engine/cancellable_engine.py b/voicevox_engine/cancellable_engine.py index 7e0dbcf7d..a812892f3 100644 --- a/voicevox_engine/cancellable_engine.py +++ b/voicevox_engine/cancellable_engine.py @@ -19,7 +19,7 @@ from .core.core_initializer import initialize_cores from .metas.Metas import StyleId from .model import AudioQuery -from .tts_pipeline.tts_engine import make_tts_engines_from_cores +from .tts_pipeline.tts_engine import LatestVersion, make_tts_engines_from_cores class CancellableEngineInternalError(Exception): @@ -149,7 +149,7 @@ def _synthesis_impl( query: AudioQuery, style_id: StyleId, request: Request, - core_version: str | None, + version: str | LatestVersion, ) -> str: """ 音声合成を行う関数 @@ -163,7 +163,7 @@ def _synthesis_impl( request: fastapi.Request 接続確立時に受け取ったものをそのまま渡せばよい https://fastapi.tiangolo.com/advanced/using-request-directly/ - core_version + version Returns ------- @@ -173,7 +173,7 @@ def _synthesis_impl( proc, sub_proc_con1 = self.procs_and_cons.get() self.watch_con_list.append((request, proc)) try: - sub_proc_con1.send((query, style_id, core_version)) + sub_proc_con1.send((query, style_id, version)) f_name = sub_proc_con1.recv() if isinstance(f_name, str): audio_file_name = f_name @@ -244,9 +244,9 @@ def start_synthesis_subprocess( assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。" while True: try: - query, style_id, core_version = sub_proc_con.recv() + query, style_id, version = sub_proc_con.recv() try: - _engine = tts_engines.get_engine(core_version) + _engine = tts_engines.get_engine(version) except Exception: # バージョンが見つからないエラー sub_proc_con.send("") diff --git a/voicevox_engine/tts_pipeline/tts_engine.py b/voicevox_engine/tts_pipeline/tts_engine.py index 8b551c41e..73eea1afe 100644 --- a/voicevox_engine/tts_pipeline/tts_engine.py +++ b/voicevox_engine/tts_pipeline/tts_engine.py @@ -2,6 +2,7 @@ import copy import math +from typing import Literal, TypeAlias import numpy as np from fastapi import HTTPException @@ -699,6 +700,10 @@ def frame_synthsize_wave( return wave +LatestVersion: TypeAlias = Literal["LATEST_VERSION"] +LATEST_VERSION: LatestVersion = "LATEST_VERSION" + + class TTSEngineManager: """TTS エンジンの集まりを一括管理するマネージャー""" @@ -716,10 +721,11 @@ def register_engine(self, engine: TTSEngine, version: str) -> None: """エンジンを登録する。""" self._engines[version] = engine - def get_engine(self, core_version: str | None) -> TTSEngine: + def get_engine(self, version: str | LatestVersion) -> TTSEngine: """指定バージョンのエンジンを取得する。""" - version = core_version or self._latest_version() - if version in self._engines: + if version == LATEST_VERSION: + return self._engines[self._latest_version()] + elif version in self._engines: return self._engines[version] raise HTTPException(status_code=422, detail="不明なバージョンです") From d2bb45b36ba6ff1403e5a899a6613ce1c0857110 Mon Sep 17 00:00:00 2001 From: tarepan Date: Thu, 27 Jun 2024 10:08:54 +0000 Subject: [PATCH 5/5] =?UTF-8?q?refactor:=20=E5=AE=9A=E6=95=B0=E3=82=92=20F?= =?UTF-8?q?inal=20=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- voicevox_engine/tts_pipeline/tts_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/voicevox_engine/tts_pipeline/tts_engine.py b/voicevox_engine/tts_pipeline/tts_engine.py index 73eea1afe..5cde61724 100644 --- a/voicevox_engine/tts_pipeline/tts_engine.py +++ b/voicevox_engine/tts_pipeline/tts_engine.py @@ -2,7 +2,7 @@ import copy import math -from typing import Literal, TypeAlias +from typing import Final, Literal, TypeAlias import numpy as np from fastapi import HTTPException @@ -701,7 +701,7 @@ def frame_synthsize_wave( LatestVersion: TypeAlias = Literal["LATEST_VERSION"] -LATEST_VERSION: LatestVersion = "LATEST_VERSION" +LATEST_VERSION: Final[LatestVersion] = "LATEST_VERSION" class TTSEngineManager: