From ccbda99302accf2c06a3c9696d26a07c374a8e1a Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Tue, 10 Sep 2024 23:22:12 +0900 Subject: [PATCH] =?UTF-8?q?"join"=E3=81=97=E3=81=AA=E3=81=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/voicevox_core/src/infer/domains.rs | 9 --- crates/voicevox_core/src/voice_model.rs | 75 ++++++++++------------- 2 files changed, 33 insertions(+), 51 deletions(-) diff --git a/crates/voicevox_core/src/infer/domains.rs b/crates/voicevox_core/src/infer/domains.rs index 9cae20b2e..5225f2ec3 100644 --- a/crates/voicevox_core/src/infer/domains.rs +++ b/crates/voicevox_core/src/infer/domains.rs @@ -1,7 +1,5 @@ mod talk; -use std::future::Future; - use educe::Educe; use serde::{Deserialize, Deserializer}; @@ -40,13 +38,6 @@ impl InferenceDomainMap<(Result,)> { } } -impl InferenceDomainMap<(T,)> { - pub(crate) async fn join_all(self) -> InferenceDomainMap<(T::Output,)> { - let talk = self.talk.await; - InferenceDomainMap { talk } - } -} - impl<'de, V: InferenceDomainMapValues + ?Sized> Deserialize<'de> for InferenceDomainMap where V::Talk: Deserialize<'de>, diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 9d66c13a3..aed060633 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -14,7 +14,7 @@ use easy_ext::ext; use enum_map::enum_map; use enum_map::EnumMap; use futures_io::{AsyncBufRead, AsyncSeek}; -use futures_util::future::{FutureExt as _, OptionFuture, TryFutureExt as _}; +use futures_util::future::{OptionFuture, TryFutureExt as _}; use itertools::Itertools as _; use ouroboros::self_referencing; use serde::Deserialize; @@ -209,44 +209,36 @@ impl Inner { } self.with_inference_model_entries(|inference_model_entries| { - inference_model_entries - .each_ref() - .map(InferenceDomainMap { - talk: |talk| { - let talk = - talk.as_ref() - .map(|InferenceModelEntry { indices, manifest }| { - ( - indices.map(|op, i| (i, manifest[op].clone())), - manifest.style_id_to_inner_voice_id.clone(), - ) - }); - async { - OptionFuture::from(talk.map( - |(entries, style_id_to_inner_voice_id)| async { - let [predict_duration, predict_intonation, decode] = - entries.into_array(); - - let predict_duration = read_file!(predict_duration); - let predict_intonation = read_file!(predict_intonation); - let decode = read_file!(decode); - - let model_bytes = EnumMap::from_array([ - predict_duration, - predict_intonation, - decode, - ]); - - Ok((style_id_to_inner_voice_id, model_bytes)) - }, - )) - .await - .transpose() - } - }, - }) - .join_all() - .map(InferenceDomainMap::collect) + let talk = inference_model_entries.talk.as_ref().map( + |InferenceModelEntry { indices, manifest }| { + ( + indices.map(|op, i| (i, manifest[op].clone())), + manifest.style_id_to_inner_voice_id.clone(), + ) + }, + ); + + async { + let talk = async { + OptionFuture::from(talk.map(|(entries, style_id_to_inner_voice_id)| async { + let [predict_duration, predict_intonation, decode] = entries.into_array(); + + let predict_duration = read_file!(predict_duration); + let predict_intonation = read_file!(predict_intonation); + let decode = read_file!(decode); + + let model_bytes = + EnumMap::from_array([predict_duration, predict_intonation, decode]); + + Ok((style_id_to_inner_voice_id, model_bytes)) + })) + .await + .transpose() + } + .await?; + + Ok(InferenceDomainMap { talk }) + } }) .await } @@ -264,9 +256,8 @@ struct InferenceModelEntry { impl A { async fn open_zip( path: &Path, - ) -> anyhow::Result< - async_zip::base::read::seek::ZipFileReader, - > { + ) -> anyhow::Result> + { let zip = Self::open_file(path).await.with_context(|| { // fs-errのと同じにする format!("failed to open file `{}`", path.display())