Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: TTSEngine メソッド引数に CoreAdapter を追加 #1392

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
5 changes: 3 additions & 2 deletions test/unit/test_mock_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,20 @@ def _gen_accent_phrases() -> list[AccentPhrase]:
def test_update_length() -> None:
"""`.update_length()` がエラー無く生成をおこなう"""
engine = MockTTSEngine()
engine.update_length(_gen_accent_phrases(), StyleId(0))
engine.update_length(engine._core, _gen_accent_phrases(), StyleId(0))


def test_update_pitch() -> None:
"""`.update_pitch()` がエラー無く生成をおこなう"""
engine = MockTTSEngine()
engine.update_pitch(_gen_accent_phrases(), StyleId(0))
engine.update_pitch(engine._core, _gen_accent_phrases(), StyleId(0))


def test_synthesize_wave() -> None:
"""`.synthesize_wave()` がエラー無く生成をおこなう"""
engine = MockTTSEngine()
engine.synthesize_wave(
engine._core,
AudioQuery(
accent_phrases=_gen_accent_phrases(),
speedScale=1,
Expand Down
42 changes: 26 additions & 16 deletions test/unit/tts_pipeline/test_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_update_length() -> None:
# Inputs
hello_hiho = _gen_hello_hiho_accent_phrases()
# Indirect Outputs(yukarin_sに渡される値)
tts_engine.update_length(hello_hiho, StyleId(1))
tts_engine.update_length(tts_engine._core, hello_hiho, StyleId(1))
yukarin_s_args = _yukarin_s_mock.call_args[1]
list_length = yukarin_s_args["length"]
phoneme_list = yukarin_s_args["phoneme_list"]
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_update_pitch() -> None:
# Inputs
phrases: list = []
# Outputs
result = tts_engine.update_pitch(phrases, StyleId(1))
result = tts_engine.update_pitch(tts_engine._core, phrases, StyleId(1))
# Expects
true_result: list = []
# Tests
Expand All @@ -261,7 +261,7 @@ def test_update_pitch() -> None:
# Inputs
hello_hiho = _gen_hello_hiho_accent_phrases()
# Indirect Outputs(yukarin_saに渡される値)
tts_engine.update_pitch(hello_hiho, StyleId(1))
tts_engine.update_pitch(tts_engine._core, hello_hiho, StyleId(1))
yukarin_sa_args = _yukarin_sa_mock.call_args[1]
list_length = yukarin_sa_args["length"]
vowel_phoneme_list = yukarin_sa_args["vowel_phoneme_list"][0]
Expand Down Expand Up @@ -305,7 +305,9 @@ def test_create_accent_phrases_toward_unknown() -> None:
"dummy", text_to_features=stub_unknown_features_koxx
)
with pytest.raises(ValueError) as e:
accent_phrases = engine.update_length_and_pitch(accent_phrases, StyleId(0))
accent_phrases = engine.update_length_and_pitch(
engine._core, accent_phrases, StyleId(0)
)
assert str(e.value) == "tuple.index(x): x not in tuple"


Expand All @@ -315,7 +317,7 @@ def test_mocked_update_length_output(snapshot_json: SnapshotAssertion) -> None:
tts_engine = TTSEngine(MockCoreWrapper())
hello_hiho = _gen_hello_hiho_accent_phrases()
# Outputs
result = tts_engine.update_length(hello_hiho, StyleId(1))
result = tts_engine.update_length(tts_engine._core, hello_hiho, StyleId(1))
# Tests
assert snapshot_json == round_floats(pydantic_to_native_type(result), round_value=2)

Expand All @@ -326,7 +328,7 @@ def test_mocked_update_pitch_output(snapshot_json: SnapshotAssertion) -> None:
tts_engine = TTSEngine(MockCoreWrapper())
hello_hiho = _gen_hello_hiho_accent_phrases()
# Outputs
result = tts_engine.update_pitch(hello_hiho, StyleId(1))
result = tts_engine.update_pitch(tts_engine._core, hello_hiho, StyleId(1))
# Tests
assert snapshot_json == round_floats(pydantic_to_native_type(result), round_value=2)

Expand All @@ -339,7 +341,9 @@ def test_mocked_update_length_and_pitch_output(
tts_engine = TTSEngine(MockCoreWrapper())
hello_hiho = _gen_hello_hiho_accent_phrases()
# Outputs
result = tts_engine.update_length_and_pitch(hello_hiho, StyleId(1))
result = tts_engine.update_length_and_pitch(
tts_engine._core, hello_hiho, StyleId(1)
)
# Tests
assert snapshot_json == round_floats(pydantic_to_native_type(result), round_value=2)

Expand All @@ -352,7 +356,7 @@ def test_mocked_create_accent_phrases_output(
tts_engine = TTSEngine(MockCoreWrapper())
hello_hiho = _gen_hello_hiho_text()
# Outputs
result = tts_engine.create_accent_phrases(hello_hiho, StyleId(1))
result = tts_engine.create_accent_phrases(tts_engine._core, hello_hiho, StyleId(1))
# Tests
assert snapshot_json == round_floats(pydantic_to_native_type(result), round_value=2)

Expand All @@ -365,7 +369,9 @@ def test_mocked_create_accent_phrases_from_kana_output(
tts_engine = TTSEngine(MockCoreWrapper())
hello_hiho = _gen_hello_hiho_kana()
# Outputs
result = tts_engine.create_accent_phrases_from_kana(hello_hiho, StyleId(1))
result = tts_engine.create_accent_phrases_from_kana(
tts_engine._core, hello_hiho, StyleId(1)
)
# Tests
assert snapshot_json == round_floats(pydantic_to_native_type(result), round_value=2)

Expand All @@ -376,7 +382,7 @@ def test_mocked_synthesize_wave_output(snapshot_json: SnapshotAssertion) -> None
tts_engine = TTSEngine(MockCoreWrapper())
hello_hiho = _gen_hello_hiho_query()
# Outputs
result = tts_engine.synthesize_wave(hello_hiho, StyleId(1))
result = tts_engine.synthesize_wave(tts_engine._core, hello_hiho, StyleId(1))
# Tests
assert snapshot_json == round_floats(result.tolist(), round_value=2)

Expand All @@ -392,11 +398,11 @@ def test_mocked_create_sing_volume_from_phoneme_and_f0_output(
tts_engine = TTSEngine(MockCoreWrapper())
doremi_srore = _gen_doremi_score()
phonemes, f0s, _ = tts_engine.create_sing_phoneme_and_f0_and_volume(
doremi_srore, StyleId(1)
tts_engine._core, doremi_srore, StyleId(1)
)
# Outputs
result = tts_engine.create_sing_volume_from_phoneme_and_f0(
doremi_srore, phonemes, f0s, StyleId(1)
tts_engine._core, doremi_srore, phonemes, f0s, StyleId(1)
)
# Tests
assert snapshot_json == round_floats(result, round_value=2)
Expand All @@ -413,7 +419,9 @@ def test_mocked_synthesize_wave_from_score_output(
tts_engine = TTSEngine(MockCoreWrapper())
doremi_srore = _gen_doremi_score()
# Outputs
result = tts_engine.create_sing_phoneme_and_f0_and_volume(doremi_srore, StyleId(1))
result = tts_engine.create_sing_phoneme_and_f0_and_volume(
tts_engine._core, doremi_srore, StyleId(1)
)
# Tests
assert snapshot_json(name="query") == round_floats(
pydantic_to_native_type(result), round_value=2
Expand All @@ -430,7 +438,9 @@ def test_mocked_synthesize_wave_from_score_output(
outputStereo=False,
)
# Outputs
result_wave = tts_engine.frame_synthsize_wave(doremi_query, StyleId(1))
result_wave = tts_engine.frame_synthsize_wave(
tts_engine._core, doremi_query, StyleId(1)
)
# Tests
assert snapshot_json(name="wave") == round_floats(
result_wave.tolist(), round_value=2
Expand Down Expand Up @@ -527,7 +537,7 @@ def create_synthesis_test_base(
(https://github.com/VOICEVOX/voicevox_engine/issues/272#issuecomment-1022610866)
"""
tts_engine = TTSEngine(core=MockCoreWrapper())
inputs = tts_engine.create_accent_phrases(text, StyleId(1))
inputs = tts_engine.create_accent_phrases(tts_engine._core, text, StyleId(1))
outputs = apply_interrogative_upspeak(inputs, enable_interrogative_upspeak)
assert expected == outputs, f"case(text:{text})"

Expand All @@ -540,7 +550,7 @@ def test_create_accent_phrases() -> None:
text = "これはありますか?"
expected = koreha_arimasuka_base_expected()
expected[-1].is_interrogative = True
actual = tts_engine.create_accent_phrases(text, StyleId(1))
actual = tts_engine.create_accent_phrases(tts_engine._core, text, StyleId(1))
assert expected == actual, f"case(text:{text})"


Expand Down
37 changes: 25 additions & 12 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def audio_query(
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
accent_phrases = engine.create_accent_phrases(text, style_id)
accent_phrases = engine.create_accent_phrases(engine._core, text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
speedScale=1,
Expand Down Expand Up @@ -130,7 +130,9 @@ def audio_query_from_preset(
status_code=422, detail="該当するプリセットIDが見つかりません"
)

accent_phrases = engine.create_accent_phrases(text, selected_preset.style_id)
accent_phrases = engine.create_accent_phrases(
engine._core, text, selected_preset.style_id
)
return AudioQuery(
accent_phrases=accent_phrases,
speedScale=selected_preset.speedScale,
Expand Down Expand Up @@ -173,13 +175,15 @@ def accent_phrases(
engine = tts_engines.get_engine(core_version)
if is_kana:
try:
return engine.create_accent_phrases_from_kana(text, style_id)
return engine.create_accent_phrases_from_kana(
engine._core, text, style_id
)
except ParseKanaError as err:
raise HTTPException(
status_code=400, detail=ParseKanaBadRequest(err).dict()
)
else:
return engine.create_accent_phrases(text, style_id)
return engine.create_accent_phrases(engine._core, text, style_id)

@router.post(
"/mora_data",
Expand All @@ -192,7 +196,7 @@ def mora_data(
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
return engine.update_length_and_pitch(accent_phrases, style_id)
return engine.update_length_and_pitch(engine._core, accent_phrases, style_id)

@router.post(
"/mora_length",
Expand All @@ -205,7 +209,7 @@ def mora_length(
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
return engine.update_length(accent_phrases, style_id)
return engine.update_length(engine._core, accent_phrases, style_id)

@router.post(
"/mora_pitch",
Expand All @@ -218,7 +222,7 @@ def mora_pitch(
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
return engine.update_pitch(accent_phrases, style_id)
return engine.update_pitch(engine._core, accent_phrases, style_id)

@router.post(
"/synthesis",
Expand Down Expand Up @@ -246,7 +250,10 @@ def synthesis(
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
engine._core,
query,
style_id,
enable_interrogative_upspeak=enable_interrogative_upspeak,
)

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -333,7 +340,9 @@ def multi_synthesis(
)

with TemporaryFile() as wav_file:
wave = engine.synthesize_wave(queries[i], style_id)
wave = engine.synthesize_wave(
engine._core, queries[i], style_id
)
soundfile.write(
file=wav_file,
data=wave,
Expand Down Expand Up @@ -366,7 +375,7 @@ def sing_frame_audio_query(
core = core_manager.get_core(core_version)
try:
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
score, style_id
engine._core, score, style_id
)
except TalkSingInvalidInputError as e:
raise HTTPException(status_code=400, detail=str(e))
Expand Down Expand Up @@ -394,7 +403,11 @@ def sing_frame_volume(
engine = tts_engines.get_engine(core_version)
try:
return engine.create_sing_volume_from_phoneme_and_f0(
score, frame_audio_query.phonemes, frame_audio_query.f0, style_id
engine._core,
score,
frame_audio_query.phonemes,
frame_audio_query.f0,
style_id,
)
except TalkSingInvalidInputError as e:
raise HTTPException(status_code=400, detail=str(e))
Expand All @@ -421,7 +434,7 @@ def frame_synthesis(
"""
engine = tts_engines.get_engine(core_version)
try:
wave = engine.frame_synthsize_wave(query, style_id)
wave = engine.frame_synthsize_wave(engine._core, query, style_id)
except TalkSingInvalidInputError as e:
raise HTTPException(status_code=400, detail=str(e))

Expand Down
2 changes: 1 addition & 1 deletion voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def start_synthesis_subprocess(
continue
# FIXME: enable_interrogative_upspeakフラグをWebAPIから受け渡してくる
wave = _engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=False
_engine._core, query, style_id, enable_interrogative_upspeak=False
)
with NamedTemporaryFile(delete=False) as f:
soundfile.write(
Expand Down
10 changes: 7 additions & 3 deletions voicevox_engine/dev/tts_engine/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pyopenjtalk import tts
from soxr import resample

from voicevox_engine.core.core_adapter import CoreAdapter

from ...metas.Metas import StyleId
from ...model import AudioQuery
from ...tts_pipeline.tts_engine import TTSEngine, to_flatten_moras
Expand All @@ -20,8 +22,9 @@ class MockTTSEngine(TTSEngine):
def __init__(self) -> None:
super().__init__(MockCoreWrapper())

@staticmethod
def synthesize_wave(
self,
core: CoreAdapter,
query: AudioQuery,
style_id: StyleId,
enable_interrogative_upspeak: bool = True,
Expand All @@ -34,14 +37,15 @@ def synthesize_wave(
flatten_moras = to_flatten_moras(query.accent_phrases)
kana_text = "".join([mora.text for mora in flatten_moras])

wave = self.forward(kana_text)
wave = MockTTSEngine.forward(kana_text)

# volume
wave *= query.volumeScale

return wave

def forward(self, text: str, **kwargs: dict[str, Any]) -> NDArray[np.float32]:
@staticmethod
def forward(text: str, **kwargs: dict[str, Any]) -> NDArray[np.float32]:
"""
forward tts via pyopenjtalk.tts()
参照→TTSEngine のdocstring [Mock]
Expand Down
8 changes: 6 additions & 2 deletions voicevox_engine/morphing/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,12 @@ def synthesis_morphing_parameter(
# WORLDに掛けるため合成はモノラルで行う
query.outputStereo = False

base_wave = engine.synthesize_wave(query, base_style_id).astype(np.double)
target_wave = engine.synthesize_wave(query, target_style_id).astype(np.double)
base_wave = engine.synthesize_wave(engine._core, query, base_style_id).astype(
np.double
)
target_wave = engine.synthesize_wave(engine._core, query, target_style_id).astype(
np.double
)

fs = query.outputSamplingRate
frame_period = 1.0
Expand Down
Loading