Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DirectML版の推論速度が遅いっぽい #422

Closed
1 of 3 tasks
Hiroshiba opened this issue Feb 3, 2023 · 5 comments
Closed
1 of 3 tasks

DirectML版の推論速度が遅いっぽい #422

Hiroshiba opened this issue Feb 3, 2023 · 5 comments

Comments

@Hiroshiba
Copy link
Member

Hiroshiba commented Feb 3, 2023

不具合の内容

エンジン側で発覚し、コアで測定した結果どうやらコアの時点で遅いっぽいことがわかりました。

現象・ログ

0.14.0のコア(wheel)で、CPU/GPUで2回ずつ計測

CPU
Total time: 8.854429006576538 seconds

GPU
Total time: 7.63310980796814 seconds

CPU
Total time: 11.323852300643921 seconds

GPU
Total time: 7.5207672119140625 seconds

0.13.3を用いた場合、(エンジン経由で)2秒で完了したのでおそらく0.14で遅くなったのかなと思っています。

再現手順

0.14.0のDirectML版で実行してみる

VOICEVOXのバージョン

0.14.0

OSの種類/ディストリ/バージョン

  • Windows
  • macOS
  • Linux

その他

検証コード(wheel版をインストールしました)

import dataclasses
import json
import logging
import time
from argparse import ArgumentParser
from pathlib import Path
from typing import Tuple

import voicevox_core
from voicevox_core import AccelerationMode, AudioQuery, VoicevoxCore

SPEAKER_ID = 0


def tts(core, text: str):
    audio_query = core.audio_query(text, SPEAKER_ID)
    wav = core.synthesis(audio_query, SPEAKER_ID)
    return wav


def main() -> None:
    logging.basicConfig(
        format="[%(levelname)s] %(filename)s: %(message)s", level="DEBUG"
    )
    logger = logging.getLogger(__name__)

    (acceleration_mode, open_jtalk_dict_dir, text, out) = parse_args()

    core = VoicevoxCore(
        acceleration_mode=acceleration_mode, open_jtalk_dict_dir=open_jtalk_dict_dir
    )
    core.load_model(SPEAKER_ID)

    tts(core, "テスト実行")

    start_second = time.time()
    for i in range(10):
        wav = tts(core, text)
    end_second = time.time()

    logger.info("%s", f"Total time: {end_second - start_second} seconds")


def parse_args() -> Tuple[AccelerationMode, Path, str, Path]:
    argparser = ArgumentParser()
    argparser.add_argument(
        "--mode",
        default="AUTO",
        type=AccelerationMode,
        help='モード ("AUTO", "CPU", "GPU")',
    )
    argparser.add_argument(
        "--open_jtalk_dict_dir",
        type=Path,
        default="voicevox_core/open_jtalk_dic_utf_8-1.11/",
        help="Open JTalkの辞書ディレクトリ",
    )
    argparser.add_argument(
        "text",
        help="読み上げさせたい文章",
    )
    argparser.add_argument(
        "--out",
        type=Path,
        help="出力wavファイルのパス",
        default="/tmp/out.wav",
    )
    args = argparser.parse_args()
    return (args.mode, args.open_jtalk_dict_dir, args.text, args.out)


def display_as_json(audio_query: AudioQuery) -> str:
    return json.dumps(dataclasses.asdict(audio_query), ensure_ascii=False)


if __name__ == "__main__":
    main()
@Hiroshiba
Copy link
Member Author

Hiroshiba commented Feb 3, 2023

DirectML利用時はVRAM使用量は0.13のときと同様程度に上がっていました。

またDirectML版実行時にこのようなログが出ていました。
(一部CPUとして利用されている・・・・・・?)

[INFO] lib.rs: "Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf."

onnxモデルは0.14.0と0.13.3で変更がないはずですが、念のために一旦確認してみます。

@Hiroshiba
Copy link
Member Author

onnxモデルは0.14.0と0.13.3で変更がないはずですが、念のために一旦確認してみます。

こちら変更がありました。ログを出さないようにするために予め最適化したonnxを保存するようにしたためでした。

続いて0.13.3のときと0.14.0のときのonnxモデルを用いて、今のRust版コードを使って速度を比較したいと思います。

@Hiroshiba
Copy link
Member Author

わかりました、onnxの最適化が問題っぽかったです!!
CUDA版で推論に失敗している例もあるのですが、もしかしたらこのあたりに起因してるかも・・・?

@Hiroshiba
Copy link
Member Author

Hiroshiba commented Feb 3, 2023

原因がわかりました。
よくわかりませんが、たぶんCPUプロバイダーとして最適化をかけたonnxを使っているのが原因っぽいです。

そもそも最適化をかけているのはコアのログが大量に出るのを防ぐためでした。
ログが出ているのはどうやらpredict_intonationで、これはCPUで利用するため最適化をかけても問題なさそうでした。
問題はこの処理をDirectML上で実行しているdecodeでもやっていることでした。
なのでdecode以外に最適化をかけておくのが良さそうでした。

(ちなみに最適化を予めかけなくても、onnxruntimeが読み込んだときに勝手に最適化してくれます。あくまでログを出ないようにするために事前に最適化をかけていた感じです。。)

Rust用TTSテストコードメモ
    // voicevox_ttsのテスト
    #[rstest]
    #[case(
        "hello,hello,hello,hello,hello",
        0,
        VoicevoxTtsOptions {
            kana: false,
            enable_interrogative_upspeak: false,
        },
        VoicevoxResultCode::VOICEVOX_RESULT_OK
    )]
    fn voicevox_tts_works(
        #[case] text: &str,
        #[case] speaker_id: u32,
        #[case] options: VoicevoxTtsOptions,
        #[case] expected: VoicevoxResultCode,
    ) {
        let mut initialize_options = voicevox_make_default_initialize_options();
        initialize_options.open_jtalk_dict_dir =
            "C:\\Users\\hihok\\Github\\voicevox_core\\hiho_voicevox_core\\open_jtalk_dic_utf_8-1.11"
                .as_ptr() as *const c_char;
        initialize_options.acceleration_mode =
            VoicevoxAccelerationMode::VOICEVOX_ACCELERATION_MODE_GPU;

        let actual = voicevox_initialize(initialize_options);
        assert_eq!(VoicevoxResultCode::VOICEVOX_RESULT_OK, actual);

        let actual = voicevox_load_model(0);
        assert_eq!(VoicevoxResultCode::VOICEVOX_RESULT_OK, actual);

        let mut output_wav_length: usize = 0;
        let mut output_wav: *mut u8 = std::ptr::null_mut();

        // とりあえず1回
        let actual = unsafe {
            voicevox_tts(
                text.as_ptr() as *const c_char,
                speaker_id,
                VoicevoxTtsOptions {
                    kana: options.kana,
                    enable_interrogative_upspeak: options.enable_interrogative_upspeak,
                },
                &mut output_wav_length as *mut usize,
                &mut output_wav as *mut *mut u8,
            )
        };
        assert_eq!(expected, actual);

        // 10回やって時間測定
        let start = std::time::Instant::now();
        for _ in 0..10 {
            let actual = unsafe {
                voicevox_tts(
                    text.as_ptr() as *const c_char,
                    speaker_id,
                    VoicevoxTtsOptions {
                        kana: options.kana,
                        enable_interrogative_upspeak: options.enable_interrogative_upspeak,
                    },
                    &mut output_wav_length as *mut usize,
                    &mut output_wav as *mut *mut u8,
                )
            };
            assert_eq!(expected, actual);
        }
        let end = std::time::Instant::now();
        println!("{}ms", end.duration_since(start).as_millis());

        unsafe {
            voicevox_wav_free(output_wav);
        }
    }

@Hiroshiba
Copy link
Member Author

問題を解決した製品版をリリースしました!
https://github.com/VOICEVOX/voicevox_core/releases/tag/0.14.1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

1 participant