[project-s] Add pitch contour inference (#531)
Hiroshiba committed Jan 27, 2024
1 parent 1895ab5 commit c7dfb03
Showing 8 changed files with 192 additions and 7 deletions.
3 changes: 3 additions & 0 deletions crates/voicevox_core/src/error.rs
@@ -55,6 +55,9 @@ pub enum Error {
)]
InvalidModelIndex { model_index: usize },

#[error("{}", base_error_message(VOICEVOX_RESULT_UNSUPPORTED_MODEL_ERROR))]
UnsupportedModel,

#[error("{}", base_error_message(VOICEVOX_RESULT_INFERENCE_ERROR))]
InferenceFailed,

117 changes: 112 additions & 5 deletions crates/voicevox_core/src/publish.rs
@@ -140,6 +140,21 @@ impl VoicevoxCore {
)
}

pub fn predict_contour(
&mut self,
length: usize,
f0_discrete: &[f32],
phoneme_vector: &[i64],
speaker_id: u32,
) -> Result<Vec<f32>> {
self.synthesis_engine.inference_core_mut().predict_contour(
length,
f0_discrete,
phoneme_vector,
speaker_id,
)
}

pub fn decode(
&mut self,
length: usize,
@@ -371,15 +386,15 @@ impl InferenceCore {
let input_tensors: Vec<&mut dyn AnyArray> =
vec![&mut phoneme_vector_array, &mut speaker_id_array];

-        let mut output = status.predict_duration_session_run(model_index, input_tensors)?;
+        let mut duration = status.predict_duration_session_run(model_index, input_tensors)?;

-        for output_item in output.iter_mut() {
-            if *output_item < PHONEME_LENGTH_MINIMAL {
-                *output_item = PHONEME_LENGTH_MINIMAL;
+        for duration_item in duration.iter_mut() {
+            if *duration_item < PHONEME_LENGTH_MINIMAL {
+                *duration_item = PHONEME_LENGTH_MINIMAL;
             }
         }

-        Ok(output)
+        Ok(duration)
}

#[allow(clippy::too_many_arguments)]
@@ -444,6 +459,59 @@ impl InferenceCore {
status.predict_intonation_session_run(model_index, input_tensors)
}

pub fn predict_contour(
&mut self,
length: usize,
f0_discrete: &[f32],
phoneme_vector: &[i64],
speaker_id: u32,
) -> Result<Vec<f32>> {
if !self.initialized {
return Err(Error::UninitializedStatus);
}

let status = self
.status_option
.as_mut()
.ok_or(Error::UninitializedStatus)?;

if !status.validate_speaker_id(speaker_id) {
return Err(Error::InvalidSpeakerId { speaker_id });
}

let (model_index, speaker_id) =
if let Some((model_index, speaker_id)) = get_model_index_and_speaker_id(speaker_id) {
(model_index, speaker_id)
} else {
return Err(Error::InvalidSpeakerId { speaker_id });
};

if model_index >= MODEL_FILE_SET.models_count() {
return Err(Error::InvalidModelIndex { model_index });
}

let mut f0_discrete_array =
NdArray::new(ndarray::arr1(f0_discrete).into_shape([length, 1]).unwrap());
let mut phoneme_vector_array = NdArray::new(ndarray::arr1(phoneme_vector));
let mut speaker_id_array = NdArray::new(ndarray::arr1(&[speaker_id as i64]));

let input_tensors: Vec<&mut dyn AnyArray> = vec![
&mut f0_discrete_array,
&mut phoneme_vector_array,
&mut speaker_id_array,
];

let (mut f0_contour, voiced) =
status.predict_contour_session_run(model_index, input_tensors)?;
for (f0_contour_item, voiced_item) in f0_contour.iter_mut().zip(voiced.iter()) {
if *voiced_item < 0.0 {
*f0_contour_item = 0.0;
}
}

Ok(f0_contour)
}

pub fn decode(
&mut self,
length: usize,
@@ -598,6 +666,7 @@ pub const fn error_result_to_message(result_code: VoicevoxResultCode) -> &'static str {
VOICEVOX_RESULT_UNINITIALIZED_STATUS_ERROR => "Statusが初期化されていません\0",
VOICEVOX_RESULT_INVALID_SPEAKER_ID_ERROR => "無効なspeaker_idです\0",
VOICEVOX_RESULT_INVALID_MODEL_INDEX_ERROR => "無効なmodel_indexです\0",
VOICEVOX_RESULT_UNSUPPORTED_MODEL_ERROR => "未対応なモデルです\0",
VOICEVOX_RESULT_INFERENCE_ERROR => "推論に失敗しました\0",
VOICEVOX_RESULT_EXTRACT_FULL_CONTEXT_LABEL_ERROR => {
"入力テキストからのフルコンテキストラベル抽出に失敗しました\0"
@@ -852,6 +921,44 @@ mod tests {
assert_eq!(result.unwrap().len(), vowel_phoneme_vector.len());
}

#[rstest]
fn predict_contour_works() {
let internal = VoicevoxCore::new_with_mutex();
internal
.lock()
.unwrap()
.initialize(InitializeOptions {
load_all_models: true,
acceleration_mode: AccelerationMode::Cpu,
..Default::default()
})
.unwrap();

// Inputs corresponding to the text 「テスト」
const F0_LENGTH: usize = 69;
let mut f0_discrete = [0.; F0_LENGTH];
f0_discrete[9..24].fill(5.905218);
f0_discrete[37..60].fill(5.565851);

let mut phoneme = [0; F0_LENGTH];
phoneme[0..9].fill(0);
phoneme[9..13].fill(37);
phoneme[13..24].fill(14);
phoneme[24..30].fill(35);
phoneme[30..37].fill(6);
phoneme[37..45].fill(37);
phoneme[45..60].fill(30);
phoneme[60..69].fill(0);

let result = internal
.lock()
.unwrap()
.predict_contour(F0_LENGTH, &f0_discrete, &phoneme, 2);

assert!(result.is_ok(), "{result:?}");
assert_eq!(result.unwrap().len(), F0_LENGTH);
}

#[rstest]
fn decode_works() {
let internal = VoicevoxCore::new_with_mutex();
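For orientation, below is a minimal call-site sketch of the new public API. The only behaviour it relies on from this commit is the predict_contour signature, the zeroing of unvoiced frames inside InferenceCore::predict_contour, and the UnsupportedModel error returned for models that ship no contour network; the import paths, the helper name f0_with_optional_contour, and the fallback policy are illustrative assumptions.

use voicevox_core::{Error, Result, VoicevoxCore};

/// Hypothetical helper: prefer the contour network when the loaded model
/// provides one, otherwise keep the discrete (phrase-level) f0 as-is.
fn f0_with_optional_contour(
    core: &mut VoicevoxCore,
    f0_discrete: &[f32],
    phoneme: &[i64],
    speaker_id: u32,
) -> Result<Vec<f32>> {
    let length = f0_discrete.len();
    match core.predict_contour(length, f0_discrete, phoneme, speaker_id) {
        // Unvoiced frames come back as 0.0 (the voiced mask is applied upstream).
        Ok(f0_contour) => Ok(f0_contour),
        // This speaker's model has no predict_contour network; fall back.
        Err(Error::UnsupportedModel) => Ok(f0_discrete.to_vec()),
        Err(err) => Err(err),
    }
}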
2 changes: 2 additions & 0 deletions crates/voicevox_core/src/result_code.rs
@@ -23,6 +23,8 @@ pub enum VoicevoxResultCode {
VOICEVOX_RESULT_INVALID_SPEAKER_ID_ERROR = 7,
/// 無効なmodel_indexが指定された
VOICEVOX_RESULT_INVALID_MODEL_INDEX_ERROR = 8,
/// 対応していないmodelが指定された
VOICEVOX_RESULT_UNSUPPORTED_MODEL_ERROR = 15,
/// 推論に失敗した
VOICEVOX_RESULT_INFERENCE_ERROR = 9,
/// コンテキストラベル出力に失敗した
47 changes: 45 additions & 2 deletions crates/voicevox_core/src/status.rs
@@ -40,6 +40,7 @@ pub struct Status {
struct StatusModels {
predict_duration: BTreeMap<usize, Session<'static>>,
predict_intonation: BTreeMap<usize, Session<'static>>,
predict_contour: BTreeMap<usize, Option<Session<'static>>>,
decode: BTreeMap<usize, Session<'static>>,
}

@@ -82,14 +83,19 @@ impl ModelFileSet {
|&ModelFileNames {
predict_duration_model,
predict_intonation_model,
predict_contour_model,
decode_model,
}| {
let predict_duration_model = ModelFile::new(&path(predict_duration_model))?;
let predict_intonation_model = ModelFile::new(&path(predict_intonation_model))?;
let predict_contour_model = predict_contour_model
.map(|s| ModelFile::new(&path(s)))
.transpose()?;
let decode_model = ModelFile::new(&path(decode_model))?;
Ok(Model {
predict_duration_model,
predict_intonation_model,
predict_contour_model,
decode_model,
})
},
@@ -113,6 +119,7 @@
struct ModelFileNames {
predict_duration_model: &'static str,
predict_intonation_model: &'static str,
predict_contour_model: Option<&'static str>,
decode_model: &'static str,
}

@@ -123,6 +130,7 @@ struct DecryptModelError;
struct Model {
predict_duration_model: ModelFile,
predict_intonation_model: ModelFile,
predict_contour_model: Option<ModelFile>,
decode_model: ModelFile,
}

@@ -208,6 +216,7 @@ impl Status {
models: StatusModels {
predict_duration: BTreeMap::new(),
predict_intonation: BTreeMap::new(),
predict_contour: BTreeMap::new(),
decode: BTreeMap::new(),
},
light_session_options: SessionOptions::new(cpu_num_threads, false),
@@ -236,6 +245,11 @@ impl Status {
self.new_session(&model.predict_duration_model, &self.light_session_options)?;
let predict_intonation_session =
self.new_session(&model.predict_intonation_model, &self.light_session_options)?;
let predict_contour_session = if let Some(model) = &model.predict_contour_model {
Some(self.new_session(model, &self.light_session_options)?)
} else {
None
};
let decode_model =
self.new_session(&model.decode_model, &self.heavy_session_options)?;

@@ -245,6 +259,9 @@
self.models
.predict_intonation
.insert(model_index, predict_intonation_session);
self.models
.predict_contour
.insert(model_index, predict_contour_session);

self.models.decode.insert(model_index, decode_model);

@@ -255,8 +272,9 @@
}

pub fn is_model_loaded(&self, model_index: usize) -> bool {
-        self.models.predict_intonation.contains_key(&model_index)
-            && self.models.predict_duration.contains_key(&model_index)
+        self.models.predict_duration.contains_key(&model_index)
+            && self.models.predict_intonation.contains_key(&model_index)
+            && self.models.predict_contour.contains_key(&model_index)
             && self.models.decode.contains_key(&model_index)
}

@@ -338,6 +356,29 @@ impl Status {
}
}

pub fn predict_contour_session_run(
&mut self,
model_index: usize,
inputs: Vec<&mut dyn AnyArray>,
) -> Result<(Vec<f32>, Vec<f32>)> {
if let Some(model) = self.models.predict_contour.get_mut(&model_index) {
if let Some(model) = model {
if let Ok(output_tensors) = model.run(inputs) {
Ok((
output_tensors[0].as_slice().unwrap().to_owned(),
output_tensors[1].as_slice().unwrap().to_owned(),
))
} else {
Err(Error::InferenceFailed)
}
} else {
Err(Error::UnsupportedModel)
}
} else {
Err(Error::InvalidModelIndex { model_index })
}
}

pub fn decode_session_run(
&mut self,
model_index: usize,
@@ -383,6 +424,7 @@ mod tests {
);
assert!(status.models.predict_duration.is_empty());
assert!(status.models.predict_intonation.is_empty());
assert!(status.models.predict_contour.is_empty());
assert!(status.models.decode.is_empty());
assert!(status.supported_styles.is_empty());
}
@@ -410,6 +452,7 @@
assert_eq!(Ok(()), result);
assert_eq!(1, status.models.predict_duration.len());
assert_eq!(1, status.models.predict_intonation.len());
assert_eq!(1, status.models.predict_contour.len());
assert_eq!(1, status.models.decode.len());
}

2 changes: 2 additions & 0 deletions crates/voicevox_core/src/status/model_file.rs
@@ -11,11 +11,13 @@ pub(super) const MODEL_FILE_NAMES: &[ModelFileNames] = &[
ModelFileNames {
predict_duration_model: "predict_duration-0.onnx",
predict_intonation_model: "predict_intonation-0.onnx",
predict_contour_model: None,
decode_model: "decode-0.onnx",
},
ModelFileNames {
predict_duration_model: "predict_duration-1.onnx",
predict_intonation_model: "predict_intonation-1.onnx",
predict_contour_model: Some("predict_contour-1.onnx"),
decode_model: "decode-1.onnx",
},
];
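The Option here is what lets older voice models load without a contour network. If a contour model were later exported for the first index as well, the entry would presumably look like the following; the file name predict_contour-0.onnx is a hypothetical following the existing naming convention, not a file added by this commit.

ModelFileNames {
    predict_duration_model: "predict_duration-0.onnx",
    predict_intonation_model: "predict_intonation-0.onnx",
    // Hypothetical: only once a contour network exists for this index.
    predict_contour_model: Some("predict_contour-0.onnx"),
    decode_model: "decode-0.onnx",
},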
27 changes: 27 additions & 0 deletions crates/voicevox_core_c_api/src/compatible_engine.rs
@@ -127,6 +127,33 @@ pub extern "C" fn yukarin_sa_forward(
}
}

#[no_mangle]
pub extern "C" fn yukarin_sosf_forward(
length: i64,
f0_discrete: *mut f32,
phoneme: *mut i64,
speaker_id: *mut i64,
output: *mut f32,
) -> bool {
let result = lock_internal().predict_contour(
length as usize,
unsafe { std::slice::from_raw_parts(f0_discrete, length as usize) },
unsafe { std::slice::from_raw_parts(phoneme, length as usize) },
unsafe { *speaker_id as u32 },
);
match result {
Ok(output_vec) => {
let output_slice = unsafe { std::slice::from_raw_parts_mut(output, length as usize) };
output_slice.clone_from_slice(&output_vec);
true
}
Err(err) => {
set_message(&format!("{err}"));
false
}
}
}

#[no_mangle]
pub extern "C" fn decode_forward(
length: i64,
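To make the pointer contract of the new C symbol concrete, here is a hedged Rust-side caller sketch; in practice the symbol is consumed over the C ABI by the compatible engine, and the wrapper function and its name are assumptions. Every input buffer must hold exactly length elements, and output must be writable for length f32 values.

extern "C" {
    fn yukarin_sosf_forward(
        length: i64,
        f0_discrete: *mut f32,
        phoneme: *mut i64,
        speaker_id: *mut i64,
        output: *mut f32,
    ) -> bool;
}

/// Hypothetical safe wrapper around the exported symbol.
fn call_yukarin_sosf(
    f0_discrete: &mut [f32],
    phoneme: &mut [i64],
    mut speaker_id: i64,
) -> Option<Vec<f32>> {
    assert_eq!(f0_discrete.len(), phoneme.len());
    let length = f0_discrete.len();
    let mut output = vec![0.0_f32; length];
    let ok = unsafe {
        yukarin_sosf_forward(
            length as i64,
            f0_discrete.as_mut_ptr(),
            phoneme.as_mut_ptr(),
            &mut speaker_id,
            output.as_mut_ptr(),
        )
    };
    // On failure, compatible_engine.rs records the error text via set_message.
    ok.then_some(output)
}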
1 change: 1 addition & 0 deletions crates/voicevox_core_c_api/src/helpers.rs
@@ -29,6 +29,7 @@ pub(crate) fn into_result_code_with_error(result: CApiResult<()>) -> VoicevoxResultCode {
Err(RustApi(UninitializedStatus)) => VOICEVOX_RESULT_UNINITIALIZED_STATUS_ERROR,
Err(RustApi(InvalidSpeakerId { .. })) => VOICEVOX_RESULT_INVALID_SPEAKER_ID_ERROR,
Err(RustApi(InvalidModelIndex { .. })) => VOICEVOX_RESULT_INVALID_MODEL_INDEX_ERROR,
Err(RustApi(UnsupportedModel { .. })) => VOICEVOX_RESULT_UNSUPPORTED_MODEL_ERROR,
Err(RustApi(InferenceFailed)) => VOICEVOX_RESULT_INFERENCE_ERROR,
Err(RustApi(ExtractFullContextLabel(_))) => {
VOICEVOX_RESULT_EXTRACT_FULL_CONTEXT_LABEL_ERROR
Binary file added model/predict_contour-1.onnx
