diff --git a/constraint-evaluation-generator/src/substitution.rs b/constraint-evaluation-generator/src/substitution.rs index 83279893..b7cce408 100644 --- a/constraint-evaluation-generator/src/substitution.rs +++ b/constraint-evaluation-generator/src/substitution.rs @@ -66,9 +66,12 @@ impl AllSubstitutions { //! To re-generate, execute: //! `cargo run --bin constraint-evaluation-generator` + use ndarray::Array1; use ndarray::s; use ndarray::ArrayView2; use ndarray::ArrayViewMut2; + use ndarray::Axis; + use ndarray::Zip; use strum::Display; use strum::EnumCount; use strum::EnumIter; @@ -118,32 +121,19 @@ impl Substitutions { let derived_section_tran_start = derived_section_cons_start + self.cons.len(); let derived_section_term_start = derived_section_tran_start + self.tran.len(); - let init_col_indices = (0..self.init.len()) - .map(|i| i + derived_section_init_start) - .collect_vec(); - let cons_col_indices = (0..self.cons.len()) - .map(|i| i + derived_section_cons_start) - .collect_vec(); - let tran_col_indices = (0..self.tran.len()) - .map(|i| i + derived_section_tran_start) - .collect_vec(); - let term_col_indices = (0..self.term.len()) - .map(|i| i + derived_section_term_start) - .collect_vec(); - let init_substitutions = Self::several_substitution_rules_to_code(&self.init); let cons_substitutions = Self::several_substitution_rules_to_code(&self.cons); let tran_substitutions = Self::several_substitution_rules_to_code(&self.tran); let term_substitutions = Self::several_substitution_rules_to_code(&self.term); let init_substitutions = - Self::base_single_row_substitutions(&init_col_indices, &init_substitutions); + Self::base_single_row_substitutions(derived_section_init_start, &init_substitutions); let cons_substitutions = - Self::base_single_row_substitutions(&cons_col_indices, &cons_substitutions); + Self::base_single_row_substitutions(derived_section_cons_start, &cons_substitutions); let tran_substitutions = - Self::base_dual_row_substitutions(&tran_col_indices, &tran_substitutions); + Self::base_dual_row_substitutions(derived_section_tran_start, &tran_substitutions); let term_substitutions = - Self::base_single_row_substitutions(&term_col_indices, &term_substitutions); + Self::base_single_row_substitutions(derived_section_term_start, &term_substitutions); quote!( #[allow(unused_variables)] @@ -163,32 +153,19 @@ impl Substitutions { let derived_section_tran_start = derived_section_cons_start + self.cons.len(); let derived_section_term_start = derived_section_tran_start + self.tran.len(); - let init_col_indices = (0..self.init.len()) - .map(|i| i + derived_section_init_start) - .collect_vec(); - let cons_col_indices = (0..self.cons.len()) - .map(|i| i + derived_section_cons_start) - .collect_vec(); - let tran_col_indices = (0..self.tran.len()) - .map(|i| i + derived_section_tran_start) - .collect_vec(); - let term_col_indices = (0..self.term.len()) - .map(|i| i + derived_section_term_start) - .collect_vec(); - let init_substitutions = Self::several_substitution_rules_to_code(&self.init); let cons_substitutions = Self::several_substitution_rules_to_code(&self.cons); let tran_substitutions = Self::several_substitution_rules_to_code(&self.tran); let term_substitutions = Self::several_substitution_rules_to_code(&self.term); let init_substitutions = - Self::ext_single_row_substitutions(&init_col_indices, &init_substitutions); + Self::ext_single_row_substitutions(derived_section_init_start, &init_substitutions); let cons_substitutions = - Self::ext_single_row_substitutions(&cons_col_indices, &cons_substitutions); + Self::ext_single_row_substitutions(derived_section_cons_start, &cons_substitutions); let tran_substitutions = - Self::ext_dual_row_substitutions(&tran_col_indices, &tran_substitutions); + Self::ext_dual_row_substitutions(derived_section_tran_start, &tran_substitutions); let term_substitutions = - Self::ext_single_row_substitutions(&term_col_indices, &term_substitutions); + Self::ext_single_row_substitutions(derived_section_term_start, &term_substitutions); quote!( #[allow(unused_variables)] @@ -241,93 +218,133 @@ impl Substitutions { } fn base_single_row_substitutions( - indices: &[usize], + section_start_index: usize, substitutions: &[TokenStream], ) -> TokenStream { - assert_eq!(indices.len(), substitutions.len()); + let num_substitutions = substitutions.len(); + let indices = (0..num_substitutions).collect_vec(); if indices.is_empty() { return quote!(); } quote!( - master_base_table.rows_mut().into_iter().for_each(|mut row| { - #( - let (base_row, mut det_col) = - row.multi_slice_mut((s![..#indices],s![#indices..=#indices])); - det_col[0] = #substitutions; - )* - }); + let (original_part, mut current_section) = + master_base_table.multi_slice_mut( + ( + s![.., 0..#section_start_index], + s![.., #section_start_index..#section_start_index+#num_substitutions], + ) + ); + Zip::from(original_part.rows()) + .and(current_section.rows_mut()) + .par_for_each(|original_row, mut section_row| { + let mut base_row = original_row.to_owned(); + #( + section_row[#indices] = #substitutions; + base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); + )* + }); ) } fn base_dual_row_substitutions( - indices: &[usize], + section_start_index: usize, substitutions: &[TokenStream], ) -> TokenStream { - assert_eq!(indices.len(), substitutions.len()); + let num_substitutions = substitutions.len(); + let indices = (0..substitutions.len()).collect_vec(); if indices.is_empty() { return quote!(); } quote!( - for curr_row_idx in 0..master_base_table.nrows() - 1 { - let next_row_idx = curr_row_idx + 1; - let (mut curr_base_row, next_base_row) = master_base_table.multi_slice_mut(( - s![curr_row_idx..=curr_row_idx, ..], - s![next_row_idx..=next_row_idx, ..], - )); - let mut curr_base_row = curr_base_row.row_mut(0); - let next_base_row = next_base_row.row(0); - #( - let (current_base_row, mut det_col) = - curr_base_row.multi_slice_mut((s![..#indices], s![#indices..=#indices])); - det_col[0] = #substitutions; - )* - } + let num_rows = master_base_table.nrows(); + let (original_part, mut current_section) = + master_base_table.multi_slice_mut( + ( + s![.., 0..#section_start_index], + s![.., #section_start_index..#section_start_index+#num_substitutions], + ) + ); + let row_indices = Array1::from_vec((0..num_rows - 1).collect::>()); + Zip::from(current_section.slice_mut(s![0..num_rows-1, ..]).rows_mut()) + .and(row_indices.view()) + .par_for_each( |mut section_row, ¤t_row_index| { + let next_row_index = current_row_index + 1; + let current_base_row_slice = original_part.slice(s![current_row_index..=current_row_index, ..]); + let next_base_row_slice = original_part.slice(s![next_row_index..=next_row_index, ..]); + let mut current_base_row = current_base_row_slice.row(0).to_owned(); + let next_base_row = next_base_row_slice.row(0); + #( + section_row[#indices] = #substitutions; + current_base_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); + )* + }); ) } fn ext_single_row_substitutions( - indices: &[usize], + section_start_index: usize, substitutions: &[TokenStream], ) -> TokenStream { - assert_eq!(indices.len(), substitutions.len()); + let num_substitutions = substitutions.len(); + let indices = (0..substitutions.len()).collect_vec(); if indices.is_empty() { return quote!(); } quote!( - for row_idx in 0..master_base_table.nrows() - 1 { - let base_row = master_base_table.row(row_idx); - let mut extension_row = master_ext_table.row_mut(row_idx); - #( - let (ext_row, mut det_col) = - extension_row.multi_slice_mut((s![..#indices],s![#indices..=#indices])); - det_col[0] = #substitutions; - )* - } + let (original_part, mut current_section) = master_ext_table.multi_slice_mut( + ( + s![.., 0..#section_start_index], + s![.., #section_start_index..#section_start_index+#num_substitutions], + ) + ); + Zip::from(master_base_table.rows()) + .and(original_part.rows()) + .and(current_section.rows_mut()) + .par_for_each( + |base_table_row, original_row, mut section_row| { + let mut extension_row = original_row.to_owned(); + #( + let (original_row_extension_row, mut det_col) = + section_row.multi_slice_mut((s![..#indices],s![#indices..=#indices])); + det_col[0] = #substitutions; + extension_row.push(Axis(0), det_col.slice(s![0])).unwrap(); + )* + } + ); ) } - fn ext_dual_row_substitutions(indices: &[usize], substitutions: &[TokenStream]) -> TokenStream { - assert_eq!(indices.len(), substitutions.len()); + fn ext_dual_row_substitutions( + section_start_index: usize, + substitutions: &[TokenStream], + ) -> TokenStream { + let num_substitutions = substitutions.len(); + let indices = (0..substitutions.len()).collect_vec(); if indices.is_empty() { return quote!(); } quote!( - for curr_row_idx in 0..master_base_table.nrows() - 1 { - let next_row_idx = curr_row_idx + 1; - let current_base_row = master_base_table.row(curr_row_idx); - let next_base_row = master_base_table.row(next_row_idx); - let (mut curr_ext_row, next_ext_row) = master_ext_table.multi_slice_mut(( - s![curr_row_idx..=curr_row_idx, ..], - s![next_row_idx..=next_row_idx, ..], - )); - let mut curr_ext_row = curr_ext_row.row_mut(0); - let next_ext_row = next_ext_row.row(0); - #( - let (current_ext_row, mut det_col) = - curr_ext_row.multi_slice_mut((s![..#indices], s![#indices..=#indices])); - det_col[0] = #substitutions; - )* - } + let num_rows = master_base_table.nrows(); + let (original_part, mut current_section) = master_ext_table.multi_slice_mut( + ( + s![.., 0..#section_start_index], + s![.., #section_start_index..#section_start_index+#num_substitutions], + ) + ); + let row_indices = Array1::from_vec((0..num_rows - 1).collect::>()); + Zip::from(current_section.slice_mut(s![0..num_rows-1, ..]).rows_mut()) + .and(row_indices.view()) + .par_for_each(|mut section_row, ¤t_row_index| { + let next_row_index = current_row_index + 1; + let current_base_row = master_base_table.row(current_row_index); + let next_base_row = master_base_table.row(next_row_index); + let mut current_ext_row = original_part.row(current_row_index).to_owned(); + let next_ext_row = original_part.row(next_row_index); + #( + section_row[#indices]= #substitutions; + current_ext_row.push(Axis(0), section_row.slice(s![#indices])).unwrap(); + )* + }); ) } } diff --git a/triton-vm/src/table/master_table.rs b/triton-vm/src/table/master_table.rs index 232fd414..df0e9a13 100644 --- a/triton-vm/src/table/master_table.rs +++ b/triton-vm/src/table/master_table.rs @@ -875,6 +875,7 @@ impl MasterBaseTable { u32_table, ]; + profiler!(start "pad original tables"); Self::all_pad_functions() .into_par_iter() .zip_eq(base_tables.into_par_iter()) @@ -882,8 +883,11 @@ impl MasterBaseTable { .for_each(|((pad, base_table), table_length)| { pad(base_table, table_length); }); + profiler!(stop "pad original tables"); + profiler!(start "fill degree-lowering table"); DegreeLoweringTable::fill_derived_base_columns(self.trace_table_mut()); + profiler!(stop "fill degree-lowering table"); } fn all_pad_functions() -> [PadFunction; NUM_TABLES_WITHOUT_DEGREE_LOWERING] {