-
Notifications
You must be signed in to change notification settings - Fork 206
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add: 速度ベンチマーク * refactor: `httpx.Client` を利用 Co-authored-by: sabonerune <[email protected]> * refactor: `httpx.Client` を利用 * refactor: client 設定を明確化 * fix: `utils` を `utility` へリネーム * add: `voicevox_dir` 引数を追加 * fix: 準備をリネーム * fix: req-res ベンチマークを詳細化 * fix: コンフリクト * fix: `engine_preparation` へ module docstring を追加 * fix: 昔の不正確なコメントを削除 * fix: warning 文へ原因を追加 * fix: サーバー引数の設定を docstring へ移植 * fix: ベンチマーク実行手順を明確化 * fix: 間接的ベンチマークの意図を明確化 * fix: コンフリクト --------- Co-authored-by: sabonerune <[email protected]>
- Loading branch information
1 parent
26022a9
commit 93e3347
Showing
6 changed files
with
215 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
"""VOICEVOX ENGINE へアクセス可能なクライアントの生成""" | ||
|
||
import warnings | ||
from pathlib import Path | ||
from typing import Literal | ||
|
||
import httpx | ||
from fastapi.testclient import TestClient | ||
|
||
from voicevox_engine.app.application import generate_app | ||
from voicevox_engine.core.core_initializer import initialize_cores | ||
from voicevox_engine.preset.PresetManager import PresetManager | ||
from voicevox_engine.setting.Setting import SettingHandler | ||
from voicevox_engine.tts_pipeline.tts_engine import make_tts_engines_from_cores | ||
from voicevox_engine.user_dict.user_dict import UserDictionary | ||
from voicevox_engine.utility.core_version_utility import get_latest_version | ||
|
||
|
||
def _generate_engine_fake_server(root_dir: Path) -> TestClient: | ||
core_manager = initialize_cores( | ||
voicevox_dir=root_dir, use_gpu=False, enable_mock=False | ||
) | ||
tts_engines = make_tts_engines_from_cores(core_manager) | ||
latest_core_version = get_latest_version(tts_engines.versions()) | ||
setting_loader = SettingHandler(Path("./not_exist.yaml")) | ||
preset_manager = PresetManager(Path("./presets.yaml")) | ||
user_dict = UserDictionary() | ||
|
||
app = generate_app( | ||
tts_engines=tts_engines, | ||
core_manager=core_manager, | ||
latest_core_version=latest_core_version, | ||
setting_loader=setting_loader, | ||
preset_manager=preset_manager, | ||
root_dir=root_dir, | ||
user_dict=user_dict, | ||
) | ||
return TestClient(app) | ||
|
||
|
||
ServerType = Literal["localhost", "fake"] | ||
|
||
|
||
def generate_client( | ||
server: ServerType, root_dir: Path | None | ||
) -> TestClient | httpx.Client: | ||
""" | ||
VOICEVOX ENGINE へアクセス可能なクライアントを生成する。 | ||
`server=localhost` では http://localhost:50021 へのクライアントを生成する。 | ||
`server=fake` ではネットワークを介さずレスポンスを返す疑似サーバーを生成する。 | ||
""" | ||
|
||
if server == "fake": | ||
if root_dir is None: | ||
warn_msg = "root_dirが未指定であるため、自動的に `VOICEVOX/vv-engine` を `root_dir` に設定します。" | ||
warnings.warn(warn_msg, stacklevel=2) | ||
root_dir = Path("VOICEVOX/vv-engine") | ||
return _generate_engine_fake_server(root_dir) | ||
elif server == "localhost": | ||
return httpx.Client(base_url="http://localhost:50021") | ||
else: | ||
raise Exception(f"{server} はサポートされていないサーバータイプです") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""エンジンへのリクエストにかかる時間の測定""" | ||
|
||
import argparse | ||
from pathlib import Path | ||
from test.benchmark.engine_preparation import ServerType, generate_client | ||
from test.benchmark.speed.utility import benchmark_time | ||
|
||
|
||
def benchmark_request(server: ServerType, root_dir: Path | None = None) -> float: | ||
""" | ||
エンジンへのリクエストにかかる時間を測定する。 | ||
`GET /` はエンジン内部処理が最小であるため、全話者分のリクエスト-レスポンス(ネットワーク処理部分)にかかる時間を擬似的に計測できる。 | ||
""" | ||
|
||
client = generate_client(server, root_dir) | ||
|
||
def execute() -> None: | ||
"""計測対象となる処理を実行する""" | ||
client.get("/", params={}) | ||
|
||
average_time = benchmark_time(execute, n_repeat=10) | ||
return average_time | ||
|
||
|
||
if __name__ == "__main__": | ||
# 実行コマンドは `python -m test.benchmark.speed.request` である。 | ||
# `server="localhost"` の場合、本ベンチマーク実行に先立ってエンジン起動が必要である。 | ||
# エンジン起動コマンドの一例として以下を示す。 | ||
# (別プロセスで)`python run.py --voicevox_dir=VOICEVOX/vv-engine` | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--voicevox_dir", type=Path) | ||
args = parser.parse_args() | ||
root_dir: Path | None = args.voicevox_dir | ||
|
||
result_fakeserve = benchmark_request(server="fake", root_dir=root_dir) | ||
result_localhost = benchmark_request(server="localhost", root_dir=root_dir) | ||
print("`GET /` fakeserve: {:.4f} sec".format(result_fakeserve)) | ||
print("`GET /` localhost: {:.4f} sec".format(result_localhost)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
"""話者に関係したリクエストにかかる時間の測定""" | ||
|
||
import argparse | ||
from pathlib import Path | ||
from test.benchmark.engine_preparation import ServerType, generate_client | ||
from test.benchmark.speed.utility import benchmark_time | ||
|
||
|
||
def benchmark_get_speakers(server: ServerType, root_dir: Path | None = None) -> float: | ||
"""`GET /speakers` にかかる時間を測定する。""" | ||
|
||
client = generate_client(server, root_dir) | ||
|
||
def execute() -> None: | ||
"""計測対象となる処理を実行する""" | ||
client.get("/speakers", params={}) | ||
|
||
average_time = benchmark_time(execute, n_repeat=10) | ||
return average_time | ||
|
||
|
||
def benchmark_get_speaker_info_all( | ||
server: ServerType, root_dir: Path | None = None | ||
) -> float: | ||
"""全話者への `GET /speaker_info` にかかる時間を測定する。""" | ||
|
||
client = generate_client(server, root_dir) | ||
|
||
# speaker_uuid 一覧を準備 | ||
response = client.get("/speakers", params={}) | ||
assert response.status_code == 200 | ||
speakers = response.json() | ||
speaker_uuids = list(map(lambda speaker: speaker["speaker_uuid"], speakers)) | ||
|
||
def execute() -> None: | ||
"""計測対象となる処理を実行する""" | ||
for speaker_uuid in speaker_uuids: | ||
client.get("/speaker_info", params={"speaker_uuid": speaker_uuid}) | ||
|
||
average_time = benchmark_time(execute, n_repeat=10) | ||
return average_time | ||
|
||
|
||
def benchmark_request_time_for_all_speakers( | ||
server: ServerType, root_dir: Path | None = None | ||
) -> float: | ||
""" | ||
全話者数と同じ回数の `GET /` にかかる時間を測定する。 | ||
`GET /` はエンジン内部処理が最小であるため、全話者分のリクエスト-レスポンス(ネットワーク処理部分)にかかる時間を擬似的に計測できる。 | ||
""" | ||
|
||
client = generate_client(server, root_dir) | ||
|
||
# speaker_uuid 一覧を準備 | ||
response = client.get("/speakers", params={}) | ||
assert response.status_code == 200 | ||
speakers = response.json() | ||
speaker_uuids = list(map(lambda speaker: speaker["speaker_uuid"], speakers)) | ||
|
||
def execute() -> None: | ||
"""計測対象となる処理を実行する""" | ||
for _ in speaker_uuids: | ||
client.get("/", params={}) | ||
|
||
average_time = benchmark_time(execute, n_repeat=10) | ||
return average_time | ||
|
||
|
||
if __name__ == "__main__": | ||
# 実行コマンドは `python -m test.benchmark.speed.speaker` である。 | ||
# `server="localhost"` の場合、本ベンチマーク実行に先立ってエンジン起動が必要である。 | ||
# エンジン起動コマンドの一例として以下を示す。 | ||
# (別プロセスで)`python run.py --voicevox_dir=VOICEVOX/vv-engine` | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--voicevox_dir", type=Path) | ||
args = parser.parse_args() | ||
root_dir: Path | None = args.voicevox_dir | ||
|
||
result_speakers_fakeserve = benchmark_get_speakers("fake", root_dir) | ||
result_speakers_localhost = benchmark_get_speakers("localhost", root_dir) | ||
print("`GET /speakers` fakeserve: {:.4f} sec".format(result_speakers_fakeserve)) | ||
print("`GET /speakers` localhost: {:.4f} sec".format(result_speakers_localhost)) | ||
|
||
_result_spk_infos_fakeserve = benchmark_get_speaker_info_all("fake", root_dir) | ||
_result_spk_infos_localhost = benchmark_get_speaker_info_all("localhost", root_dir) | ||
result_spk_infos_fakeserve = "{:.3f}".format(_result_spk_infos_fakeserve) | ||
result_spk_infos_localhost = "{:.3f}".format(_result_spk_infos_localhost) | ||
print(f"全話者 `GET /speaker_info` fakeserve: {result_spk_infos_fakeserve} sec") | ||
print(f"全話者 `GET /speaker_info` localhost: {result_spk_infos_localhost} sec") | ||
|
||
req_time_all_fake = benchmark_request_time_for_all_speakers("fake", root_dir) | ||
req_time_all_local = benchmark_request_time_for_all_speakers("localhost", root_dir) | ||
print("全話者 `GET /` fakeserve: {:.3f} sec".format(req_time_all_fake)) | ||
print("全話者 `GET /` localhost: {:.3f} sec".format(req_time_all_local)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""速度ベンチマーク用のユーティリティ""" | ||
|
||
import time | ||
from typing import Callable | ||
|
||
|
||
def benchmark_time( | ||
target_function: Callable[[], None], n_repeat: int, sec_sleep: float = 1.0 | ||
) -> float: | ||
"""対象関数の平均実行時間を計測する。""" | ||
scores: list[float] = [] | ||
for _ in range(n_repeat): | ||
start = time.perf_counter() | ||
target_function() | ||
end = time.perf_counter() | ||
scores += [end - start] | ||
time.sleep(sec_sleep) | ||
average = sum(scores) / len(scores) | ||
return average |