diff --git a/.github/workflows/lint-and-test.yml b/.github/workflows/lint-and-test.yml index 238644e06..d012e43b7 100644 --- a/.github/workflows/lint-and-test.yml +++ b/.github/workflows/lint-and-test.yml @@ -234,7 +234,9 @@ jobs: with: version: nightly - name: Run tests - run: forge test --root=crates/proof-of-sql --summary --detailed + run: | + cargo run --bin yul_preprocessor crates/proof-of-sql + forge test --root=crates/proof-of-sql --summary --detailed solhint: name: solhint @@ -245,4 +247,4 @@ jobs: - name: Install solhint run: npm install -g solhint - name: Run tests - run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' -w 0 \ No newline at end of file + run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' 'crates/proof-of-sql/**/*.psol' -w 0 \ No newline at end of file diff --git a/.gitignore b/.gitignore index bd3a4853c..3d4e67f97 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,5 @@ cache # any output files from generating public params output/ + +*.p.sol \ No newline at end of file diff --git a/crates/proof-of-sql/Cargo.toml b/crates/proof-of-sql/Cargo.toml index 8772aa4ff..bb02f7cac 100644 --- a/crates/proof-of-sql/Cargo.toml +++ b/crates/proof-of-sql/Cargo.toml @@ -98,6 +98,11 @@ name = "commitment-utility" path = "utils/commitment-utility/main.rs" required-features = [ "std", "blitzar"] +[[bin]] +name = "yul_preprocessor" +path = "utils/yul-preprocessor/main.rs" +required-features = [ "std" ] + [[example]] name = "hello_world" required-features = ["test"] diff --git a/crates/proof-of-sql/sol_src/base/LagrangeBasisEvaluation.sol b/crates/proof-of-sql/sol_src/base/LagrangeBasisEvaluation.sol new file mode 100644 index 000000000..aa5c9b92c --- /dev/null +++ b/crates/proof-of-sql/sol_src/base/LagrangeBasisEvaluation.sol @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +contract LagrangeBasisEvaluation { + function computeTruncatedLagrangeBasisSum(uint256 length0, bytes memory point0, uint256 numVars0, uint256 modulus0) + public + pure + returns (uint256 result0) + { + // solhint-disable-next-line no-inline-assembly + assembly { + // START-YUL compute_truncated_lagrange_basis_sum + function compute_truncated_lagrange_basis_sum(length, point, num_vars, modulus) -> result { + let ONE := add(modulus, 1) + // result := 0 // implicitly set by the EVM + + // Invariant that holds within the for loop: + // 0 <= result <= modulus + 1 + // This invariant reduces modulus operations. + for {} num_vars {} { + switch and(length, 1) + case 0 { result := mulmod(result, sub(ONE, mod(mload(point), modulus)), modulus) } + default { result := sub(ONE, mulmod(sub(ONE, result), mload(point), modulus)) } + num_vars := sub(num_vars, 1) + length := shr(1, length) + point := add(point, 32) + } + switch length + case 0 { result := mod(result, modulus) } + default { result := 1 } + } + // END-YUL + result0 := compute_truncated_lagrange_basis_sum(length0, add(point0, 32), numVars0, modulus0) + } + } + + uint256 private constant TEST_MODULUS = 10007; + + function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith0Variables() public pure { + bytes memory point = hex""; + assert(computeTruncatedLagrangeBasisSum(1, point, 0, TEST_MODULUS) == 1); + assert(computeTruncatedLagrangeBasisSum(0, point, 0, TEST_MODULUS) == 0); + } + + function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith1Variables() public pure { + bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002"; + assert(computeTruncatedLagrangeBasisSum(2, point, 1, TEST_MODULUS) == 1); + assert(computeTruncatedLagrangeBasisSum(1, point, 1, TEST_MODULUS) == TEST_MODULUS - 1); + assert(computeTruncatedLagrangeBasisSum(0, point, 1, TEST_MODULUS) == 0); + } + + function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith2Variables() public pure { + bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002" + hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005"; + assert(computeTruncatedLagrangeBasisSum(4, point, 2, TEST_MODULUS) == 1); + assert(computeTruncatedLagrangeBasisSum(3, point, 2, TEST_MODULUS) == TEST_MODULUS - 9); + assert(computeTruncatedLagrangeBasisSum(2, point, 2, TEST_MODULUS) == TEST_MODULUS - 4); + assert(computeTruncatedLagrangeBasisSum(1, point, 2, TEST_MODULUS) == 4); + assert(computeTruncatedLagrangeBasisSum(0, point, 2, TEST_MODULUS) == 0); + } + + function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith3Variables() public pure { + bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002" + hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005" + hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000007"; + assert(computeTruncatedLagrangeBasisSum(8, point, 3, TEST_MODULUS) == 1); + assert(computeTruncatedLagrangeBasisSum(7, point, 3, TEST_MODULUS) == TEST_MODULUS - 69); + assert(computeTruncatedLagrangeBasisSum(6, point, 3, TEST_MODULUS) == TEST_MODULUS - 34); + assert(computeTruncatedLagrangeBasisSum(5, point, 3, TEST_MODULUS) == 22); + assert(computeTruncatedLagrangeBasisSum(4, point, 3, TEST_MODULUS) == TEST_MODULUS - 6); + assert(computeTruncatedLagrangeBasisSum(3, point, 3, TEST_MODULUS) == 54); + assert(computeTruncatedLagrangeBasisSum(2, point, 3, TEST_MODULUS) == 24); + assert(computeTruncatedLagrangeBasisSum(1, point, 3, TEST_MODULUS) == TEST_MODULUS - 24); + assert(computeTruncatedLagrangeBasisSum(0, point, 3, TEST_MODULUS) == 0); + } +} diff --git a/crates/proof-of-sql/sol_src/tests/TestYulImport.t.psol b/crates/proof-of-sql/sol_src/tests/TestYulImport.t.psol new file mode 100644 index 000000000..9554cb3aa --- /dev/null +++ b/crates/proof-of-sql/sol_src/tests/TestYulImport.t.psol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +library TestScript { + function testWeCanImportYulFromAnotherFile() public pure { + bytes memory point0 = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" + hex"0000000000000002" hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" + hex"0000000000000005"; + uint256 length0 = 1; + uint256 numVars0 = 2; + uint256 modulus0 = 10007; + uint256 result0; + // solhint-disable-next-line no-inline-assembly + assembly { + // IMPORT-YUL ../base/LagrangeBasisEvaluation.sol:compute_truncated_lagrange_basis_sum + // solhint-disable-next-line no-empty-blocks + function compute_truncated_lagrange_basis_sum(length, point, num_vars, modulus) -> result {} + // END-IMPORT-YUL + result0 := compute_truncated_lagrange_basis_sum(length0, add(point0, 32), numVars0, modulus0) + } + assert(result0 == 4); + } +} diff --git a/crates/proof-of-sql/src/base/polynomial/lagrange_basis_evaluation.rs b/crates/proof-of-sql/src/base/polynomial/lagrange_basis_evaluation.rs index be65d36a7..75c048611 100644 --- a/crates/proof-of-sql/src/base/polynomial/lagrange_basis_evaluation.rs +++ b/crates/proof-of-sql/src/base/polynomial/lagrange_basis_evaluation.rs @@ -74,6 +74,10 @@ where /// Given the point `point` (or `a`) with length nu, we can evaluate the lagrange basis of length 2^nu at that point. /// This is what [`super::compute_evaluation_vector`] does. +/// +/// NOTE: if length is greater than 2^nu, this function will pad `point` with 0s, which +/// will result in padding the basis with 0s. +/// /// Call the resulting evaluation vector A. This function computes `sum A[i] for i in 0..length`. That is: /// ```text /// (1-a[0])(1-a[1])...(1-a[nu-1]) + @@ -81,40 +85,22 @@ where /// (1-a[0])(a[1])...(1-a[nu-1]) + /// (a[0])(a[1])...(1-a[nu-1]) + ... /// ``` -/// # Panics -/// Panics if: -/// - The length is greater than `1` when `point` is empty. -/// - The length is greater than the maximum allowed for the given number of points, which is `2^(nu - 1)` -/// where `nu` is the number of elements in `point`. pub fn compute_truncated_lagrange_basis_sum(length: usize, point: &[F]) -> F where - F: One + Zero + Mul + Add + Sub + Copy, + F: One + Zero + Mul + Sub + Copy, { - let nu = point.len(); - if nu == 0 { - assert!(length <= 1); - if length == 1 { - F::one() - } else { - F::zero() - } + if length >= 1 << point.len() { + F::one() } else { - // Note: this is essentially the same as the inner production version. - // The only different is that the full sum is always 1, regardless of any inputs. - - let first_half_term = F::one() - point[nu - 1]; - let second_half_term = point[nu - 1]; - let half_full_length = 1 << (nu - 1); - let sub_part_length = if length >= half_full_length { - length - half_full_length - } else { - length - }; - let sub_part = compute_truncated_lagrange_basis_sum(sub_part_length, &point[..nu - 1]); - if length >= half_full_length { - first_half_term + sub_part * second_half_term - } else { - sub_part * first_half_term - } + point + .iter() + .enumerate() + .fold(F::zero(), |chi, (i, &alpha)| { + if (length >> i) & 1 == 0 { + chi * (F::one() - alpha) + } else { + F::one() - (F::one() - chi) * alpha + } + }) } } diff --git a/crates/proof-of-sql/utils/yul-preprocessor/main.rs b/crates/proof-of-sql/utils/yul-preprocessor/main.rs new file mode 100644 index 000000000..cb3bb5f97 --- /dev/null +++ b/crates/proof-of-sql/utils/yul-preprocessor/main.rs @@ -0,0 +1,171 @@ +//! This binary applies a preprocessing step to Solidity files that allows for importing Yul code from other files. + +use clap::Parser; +use snafu::Snafu; +use std::{ + fs::{self, File}, + io::{self, BufRead, BufWriter, Write}, + path::Path, +}; + +const IMPORT_YUL: &str = "// IMPORT-YUL"; +const END_IMPORT_YUL: &str = "// END-IMPORT-YUL"; +const START_YUL: &str = "// START-YUL"; +const END_YUL: &str = "// END-YUL"; +const IMPORTED_YUL: &str = "// IMPORTED-YUL"; +const END_IMPORTED_YUL: &str = "// END-IMPORTED-YUL"; + +#[derive(Debug, Snafu)] +enum Error { + #[snafu(transparent)] + Io { source: io::Error }, + #[snafu(display("Ill-formed IMPORT-YUL statement at line {line}"))] + IllFormedImportYul { line: usize }, + #[snafu(display("Unmatched END-IMPORT-YUL at line {line}"))] + UnmatchedEndImportYul { line: usize }, + #[snafu(display("Unmatched IMPORT-YUL at line"))] + UnmatchedImportYul, + #[snafu(display("Function {function_name} not found in file {file_path}"))] + FunctionNotFound { + function_name: String, + file_path: String, + }, +} + +/// A preprocessor for Solidity files to import Yul code +/// +/// This tool processes a given file or directory, replacing the import statements with the corresponding Yul code. +/// +/// # Usage +/// +/// The Yul code should be wrapped in `// START-YUL ` and `// END-YUL` comments in the source files. +/// The import statement should be in the form `// IMPORT-YUL :` in the target Solidity files. +/// +/// # Example +/// +/// Given a Solidity file `example.psol` with the following content: +/// +/// ```solidity +/// // IMPORT-YUL yul_code.sol:my_function +/// // END-IMPORT-YUL +/// ``` +/// +/// And a Yul file `yul_code.sol` with the following content: +/// +/// ```solidity +/// // START-YUL my_function +/// function my_function() -> result { +/// // Yul code here +/// } +/// // END-YUL +/// ``` +/// +/// Running the binary will produce an output file `example.p.sol` with the following content: +/// +/// ```solidity +/// // IMPORTED-YUL yul_code.sol:my_function +/// function my_function() -> result { +/// // Yul code here +/// } +/// // END-IMPORTED-YUL +/// ``` +#[derive(Parser, Debug)] +#[command(about, long_about)] +struct Args { + /// The path to the file or directory to process + path: String, +} + +fn main() -> Result<(), Error> { + let args = Args::parse(); + process_path(Path::new(&args.path))?; + Ok(()) +} + +fn process_path(path: &Path) -> Result<(), Error> { + if path.is_dir() { + for entry in fs::read_dir(path)? { + process_path(&entry?.path())?; + } + } else if path.extension().and_then(|ext| ext.to_str()) == Some("psol") { + process_file(path)?; + } + Ok(()) +} + +fn process_file(path: &Path) -> Result<(), Error> { + let file = File::open(path)?; + let reader = io::BufReader::new(file); + let mut output_lines = Vec::new(); + let mut inside_import = false; + let mut import_file = String::new(); + let mut function_name = String::new(); + let base_path = path.parent().unwrap_or_else(|| Path::new("")); + + for (line_number, line) in reader.lines().enumerate() { + let line = line?; + if let Some(import_pos) = line.find(IMPORT_YUL) { + let parts: Vec<&str> = line[import_pos + IMPORT_YUL.len()..].split(':').collect(); + if parts.len() != 2 { + return Err(Error::IllFormedImportYul { + line: line_number + 1, + }); + } + inside_import = true; + import_file = parts[0].trim().to_string(); + function_name = parts[1].trim().to_string(); + } else if line.contains(END_IMPORT_YUL) { + if !inside_import { + return Err(Error::UnmatchedEndImportYul { + line: line_number + 1, + }); + } + let function_lines = extract_function(&base_path.join(&import_file), &function_name)?; + output_lines.push(format!("{IMPORTED_YUL} {import_file}:{function_name}")); + output_lines.extend(function_lines); + output_lines.push(END_IMPORTED_YUL.to_string()); + inside_import = false; + } else if !inside_import { + output_lines.push(line); + } + } + + if inside_import { + return Err(Error::UnmatchedImportYul); + } + + let file = File::create(path.with_extension("p.sol"))?; + let mut writer = BufWriter::new(file); + for line in output_lines { + writeln!(writer, "{line}")?; + } + + Ok(()) +} + +fn extract_function(file_path: &Path, function_name: &str) -> Result, Error> { + let file = File::open(file_path)?; + let reader = io::BufReader::new(file); + let mut function_lines = Vec::new(); + let mut inside_function = false; + + for line in reader.lines() { + let line = line?; + if line.contains(&format!("{START_YUL} {function_name}")) { + inside_function = true; + } else if line.contains(END_YUL) { + break; + } else if inside_function { + function_lines.push(line); + } + } + + if !inside_function { + return Err(Error::FunctionNotFound { + function_name: function_name.to_string(), + file_path: file_path.display().to_string(), + }); + } + + Ok(function_lines) +}