Skip to content

Commit

Permalink
Revert "refactor: data modelを分離 (#256)" (#266)
Browse files Browse the repository at this point in the history
This reverts commit ff00ad6.
  • Loading branch information
takana-v authored Jan 4, 2022
1 parent 2e3af4f commit 16b1d5a
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 830 deletions.
205 changes: 69 additions & 136 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,26 @@
from fastapi.params import Query
from starlette.responses import FileResponse

from voicevox_engine import model
from voicevox_engine.cancellable_engine import CancellableEngine
from voicevox_engine.kana_parser import create_kana, parse_kana
from voicevox_engine.model import ParseKanaError
from voicevox_engine.model import (
AccentPhrase,
AudioQuery,
ParseKanaBadRequest,
ParseKanaError,
Speaker,
SpeakerInfo,
)
from voicevox_engine.morphing import synthesis_morphing
from voicevox_engine.morphing import (
synthesis_morphing_parameter as _synthesis_morphing_parameter,
)
from voicevox_engine.preset import PresetLoader
from voicevox_engine.preset import Preset, PresetLoader
from voicevox_engine.synthesis_engine import SynthesisEngineBase, make_synthesis_engine
from voicevox_engine.synthesis_engine.synthesis_engine_base import (
adjust_interrogative_accent_phrases,
)
from voicevox_engine.utility import ConnectBase64WavesException, connect_base64_waves
from voicevox_engine.webapi.fastapi_model import (
AccentPhrase,
AudioQuery,
ParseKanaBadRequest,
Preset,
Speaker,
SpeakerInfo,
)

"""
voicevox_enbine/model.pyで定義されている型は内部で使用する型なので、リクエスト及びレスポンスを行う際に使用してはならない。
リクエスト・レスポンスで使用する型はvoicevox_engine/webapi/fastapi_model.pyで定義されている型を使用し、
内部で使用している型から(or に)変換すること
"""


def b64encode_str(s):
Expand Down Expand Up @@ -100,7 +92,7 @@ def audio_query(
text: str,
speaker: int,
enable_interrogative: bool = enable_interrogative_query_param(), # noqa B008,
) -> AudioQuery:
):
"""
クエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
Expand All @@ -109,19 +101,17 @@ def audio_query(
speaker_id=speaker,
enable_interrogative=enable_interrogative,
)
return AudioQuery.from_engine(
model.AudioQuery(
accent_phrases=accent_phrases,
speedScale=1,
pitchScale=0,
intonationScale=1,
volumeScale=1,
prePhonemeLength=0.1,
postPhonemeLength=0.1,
outputSamplingRate=default_sampling_rate,
outputStereo=False,
kana=create_kana(accent_phrases),
)
return AudioQuery(
accent_phrases=accent_phrases,
speedScale=1,
pitchScale=0,
intonationScale=1,
volumeScale=1,
prePhonemeLength=0.1,
postPhonemeLength=0.1,
outputSamplingRate=default_sampling_rate,
outputStereo=False,
kana=create_kana(accent_phrases),
)

@app.post(
Expand All @@ -134,7 +124,7 @@ def audio_query_from_preset(
text: str,
preset_id: int,
enable_interrogative: bool = enable_interrogative_query_param(), # noqa B008,
) -> AudioQuery:
):
"""
クエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
Expand All @@ -153,19 +143,17 @@ def audio_query_from_preset(
speaker_id=selected_preset.style_id,
enable_interrogative=enable_interrogative,
)
return AudioQuery.from_engine(
model.AudioQuery(
accent_phrases=accent_phrases,
speedScale=selected_preset.speedScale,
pitchScale=selected_preset.pitchScale,
intonationScale=selected_preset.intonationScale,
volumeScale=selected_preset.volumeScale,
prePhonemeLength=selected_preset.prePhonemeLength,
postPhonemeLength=selected_preset.postPhonemeLength,
outputSamplingRate=default_sampling_rate,
outputStereo=False,
kana=create_kana(accent_phrases),
)
return AudioQuery(
accent_phrases=accent_phrases,
speedScale=selected_preset.speedScale,
pitchScale=selected_preset.pitchScale,
intonationScale=selected_preset.intonationScale,
volumeScale=selected_preset.volumeScale,
prePhonemeLength=selected_preset.prePhonemeLength,
postPhonemeLength=selected_preset.postPhonemeLength,
outputSamplingRate=default_sampling_rate,
outputStereo=False,
kana=create_kana(accent_phrases),
)

@app.post(
Expand All @@ -185,7 +173,7 @@ def accent_phrases(
speaker: int,
is_kana: bool = False,
enable_interrogative: bool = enable_interrogative_query_param(), # noqa B008,
) -> List[AccentPhrase]:
):
"""
テキストからアクセント句を得ます。
is_kanaが`true`のとき、テキストは次のようなAquesTalkライクな記法に従う読み仮名として処理されます。デフォルトは`false`です。
Expand All @@ -208,90 +196,46 @@ def accent_phrases(
accent_phrases=accent_phrases, speaker_id=speaker
)

return [
AccentPhrase.from_engine(accent_phrase)
for accent_phrase in (
adjust_interrogative_accent_phrases(
accent_phrases,
interrogative_accent_phrase_marks,
enable_interrogative,
)
)
]
return adjust_interrogative_accent_phrases(
accent_phrases, interrogative_accent_phrase_marks, enable_interrogative
)
else:
return [
AccentPhrase.from_engine(accent_phrase)
for accent_phrase in (
engine.create_accent_phrases(
text,
speaker_id=speaker,
enable_interrogative=enable_interrogative,
)
)
]
return engine.create_accent_phrases(
text,
speaker_id=speaker,
enable_interrogative=enable_interrogative,
)

@app.post(
"/mora_data",
response_model=List[AccentPhrase],
tags=["クエリ編集"],
summary="アクセント句から音高・音素長を得る",
)
def mora_data(
accent_phrases: List[AccentPhrase], speaker: int
) -> List[AccentPhrase]:
return [
AccentPhrase.from_engine(accent_phrase)
for accent_phrase in (
engine.replace_mora_data(
accent_phrases=[
accent_phrase.to_engine() for accent_phrase in accent_phrases
],
speaker_id=speaker,
)
)
]
def mora_data(accent_phrases: List[AccentPhrase], speaker: int):
return engine.replace_mora_data(accent_phrases, speaker_id=speaker)

@app.post(
"/mora_length",
response_model=List[AccentPhrase],
tags=["クエリ編集"],
summary="アクセント句から音素長を得る",
)
def mora_length(
accent_phrases: List[AccentPhrase], speaker: int
) -> List[AccentPhrase]:
return [
AccentPhrase.from_engine(accent_phrase)
for accent_phrase in (
engine.replace_phoneme_length(
accent_phrases=[
accent_phrase.to_engine() for accent_phrase in accent_phrases
],
speaker_id=speaker,
)
)
]
def mora_length(accent_phrases: List[AccentPhrase], speaker: int):
return engine.replace_phoneme_length(
accent_phrases=accent_phrases, speaker_id=speaker
)

@app.post(
"/mora_pitch",
response_model=List[AccentPhrase],
tags=["クエリ編集"],
summary="アクセント句から音高を得る",
)
def mora_pitch(
accent_phrases: List[AccentPhrase], speaker: int
) -> List[AccentPhrase]:
return [
AccentPhrase.from_engine(accent_phrase)
for accent_phrase in (
engine.replace_mora_pitch(
accent_phrases=[
accent_phrase.to_engine() for accent_phrase in accent_phrases
],
speaker_id=speaker,
)
)
]
def mora_pitch(accent_phrases: List[AccentPhrase], speaker: int):
return engine.replace_mora_pitch(
accent_phrases=accent_phrases, speaker_id=speaker
)

@app.post(
"/synthesis",
Expand All @@ -306,11 +250,8 @@ def mora_pitch(
tags=["音声合成"],
summary="音声合成する",
)
def synthesis(query: AudioQuery, speaker: int) -> FileResponse:
wave = engine.synthesis(
query=query.to_engine(),
speaker_id=speaker,
)
def synthesis(query: AudioQuery, speaker: int):
wave = engine.synthesis(query=query, speaker_id=speaker)

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
Expand All @@ -332,18 +273,14 @@ def synthesis(query: AudioQuery, speaker: int) -> FileResponse:
tags=["音声合成"],
summary="音声合成する(キャンセル可能)",
)
def cancellable_synthesis(
query: AudioQuery, speaker: int, request: Request
) -> FileResponse:
def cancellable_synthesis(query: AudioQuery, speaker: int, request: Request):
if not args.enable_cancellable_synthesis:
raise HTTPException(
status_code=404,
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
f_name = cancellable_engine.synthesis(
query=query.to_engine(),
speaker_id=speaker,
request=request,
query=query, speaker_id=speaker, request=request
)

return FileResponse(f_name, media_type="audio/wav")
Expand All @@ -363,7 +300,7 @@ def cancellable_synthesis(
tags=["音声合成"],
summary="複数まとめて音声合成する",
)
def multi_synthesis(queries: List[AudioQuery], speaker: int) -> FileResponse:
def multi_synthesis(queries: List[AudioQuery], speaker: int):
sampling_rate = queries[0].outputSamplingRate

with NamedTemporaryFile(delete=False) as f:
Expand All @@ -379,10 +316,7 @@ def multi_synthesis(queries: List[AudioQuery], speaker: int) -> FileResponse:

with TemporaryFile() as wav_file:

wave = engine.synthesis(
query=queries[i].to_engine(),
speaker_id=speaker,
)
wave = engine.synthesis(query=queries[i], speaker_id=speaker)
soundfile.write(
file=wav_file,
data=wave,
Expand Down Expand Up @@ -412,7 +346,7 @@ def _synthesis_morphing(
base_speaker: int,
target_speaker: int,
morph_rate: float = Query(..., ge=0.0, le=1.0), # noqa: B008
) -> FileResponse:
):
"""
指定された2人の話者で音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースの話者、1.0でターゲットの話者に近づきます。
Expand All @@ -421,7 +355,7 @@ def _synthesis_morphing(
# 生成したパラメータはキャッシュされる
morph_param = synthesis_morphing_parameter(
engine=engine,
query=query.to_engine(),
query=query,
base_speaker=base_speaker,
target_speaker=target_speaker,
)
Expand Down Expand Up @@ -455,14 +389,14 @@ def _synthesis_morphing(
tags=["その他"],
summary="base64エンコードされた複数のwavデータを一つに結合する",
)
def connect_waves(waves: List[str]) -> FileResponse:
def connect_waves(waves: List[str]):
"""
base64エンコードされたwavデータを一纏めにし、wavファイルで返します。
"""
try:
waves_nparray, sampling_rate = connect_base64_waves(waves)
except ConnectBase64WavesException as err:
raise HTTPException(status_code=422, detail=str(err))
return HTTPException(status_code=422, detail=str(err))

with NamedTemporaryFile(delete=False) as f:
soundfile.write(
Expand All @@ -475,7 +409,7 @@ def connect_waves(waves: List[str]) -> FileResponse:
return FileResponse(f.name, media_type="audio/wav")

@app.get("/presets", response_model=List[Preset], tags=["その他"])
def get_presets() -> List[Preset]:
def get_presets():
"""
エンジンが保持しているプリセットの設定を返します
Expand All @@ -487,21 +421,21 @@ def get_presets() -> List[Preset]:
presets, err_detail = preset_loader.load_presets()
if err_detail:
raise HTTPException(status_code=422, detail=err_detail)
return [preset.to_engine() for preset in presets]
return presets

@app.get("/version", tags=["その他"])
def version() -> str:
return (root_dir / "VERSION.txt").read_text()

@app.get("/speakers", response_model=List[Speaker], tags=["その他"])
def speakers() -> Response:
def speakers():
return Response(
content=engine.speakers,
media_type="application/json",
)

@app.get("/speaker_info", response_model=SpeakerInfo, tags=["その他"])
def speaker_info(speaker_uuid: str) -> SpeakerInfo:
def speaker_info(speaker_uuid: str):
"""
指定されたspeaker_uuidに関する情報をjson形式で返します。
画像や音声はbase64エンコードされたものが返されます。
Expand Down Expand Up @@ -540,17 +474,16 @@ def speaker_info(speaker_uuid: str) -> SpeakerInfo:
for j in range(3)
]
style_infos.append(
model.StyleInfo(id=id, icon=icon, voice_samples=voice_samples)
{"id": id, "icon": icon, "voice_samples": voice_samples}
)
except FileNotFoundError:
import traceback

traceback.print_exc()
raise HTTPException(status_code=500, detail="追加情報が見つかりませんでした")

return SpeakerInfo.from_engine(
model.SpeakerInfo(policy=policy, portrait=portrait, style_infos=style_infos)
)
ret_data = {"policy": policy, "portrait": portrait, "style_infos": style_infos}
return ret_data

return app

Expand Down
Loading

0 comments on commit 16b1d5a

Please sign in to comment.