Skip to content

Commit

Permalink
feat: yul lagrange basis evaluation (#442)
Browse files Browse the repository at this point in the history
# Rationale for this change

In order to port the verifier to Solidity, we need to write some code in
Yul.

# What changes are included in this PR?

Simplified the existing Rust lagrange basis evaluation.
Added a YUL port.
Added some tooling around YUL code.

# Are these changes tested?
Yes
  • Loading branch information
iajoiner authored Dec 18, 2024
2 parents c83b298 + ce07269 commit 792aaf3
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 33 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/lint-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ jobs:
with:
version: nightly
- name: Run tests
run: forge test --root=crates/proof-of-sql --summary --detailed
run: |
cargo run --bin yul_preprocessor crates/proof-of-sql
forge test --root=crates/proof-of-sql --summary --detailed
solhint:
name: solhint
Expand All @@ -245,4 +247,4 @@ jobs:
- name: Install solhint
run: npm install -g solhint
- name: Run tests
run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' -w 0
run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' 'crates/proof-of-sql/**/*.psol' -w 0
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,5 @@ cache

# any output files from generating public params
output/

*.p.sol
5 changes: 5 additions & 0 deletions crates/proof-of-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ name = "commitment-utility"
path = "utils/commitment-utility/main.rs"
required-features = [ "std", "blitzar"]

[[bin]]
name = "yul_preprocessor"
path = "utils/yul-preprocessor/main.rs"
required-features = [ "std" ]

[[example]]
name = "hello_world"
required-features = ["test"]
Expand Down
76 changes: 76 additions & 0 deletions crates/proof-of-sql/sol_src/base/LagrangeBasisEvaluation.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;

contract LagrangeBasisEvaluation {
function computeTruncatedLagrangeBasisSum(uint256 length0, bytes memory point0, uint256 numVars0, uint256 modulus0)
public
pure
returns (uint256 result0)
{
// solhint-disable-next-line no-inline-assembly
assembly {
// START-YUL compute_truncated_lagrange_basis_sum
function compute_truncated_lagrange_basis_sum(length, point, num_vars, modulus) -> result {
let ONE := add(modulus, 1)
// result := 0 // implicitly set by the EVM

// Invariant that holds within the for loop:
// 0 <= result <= modulus + 1
// This invariant reduces modulus operations.
for {} num_vars {} {
switch and(length, 1)
case 0 { result := mulmod(result, sub(ONE, mod(mload(point), modulus)), modulus) }
default { result := sub(ONE, mulmod(sub(ONE, result), mload(point), modulus)) }
num_vars := sub(num_vars, 1)
length := shr(1, length)
point := add(point, 32)
}
switch length
case 0 { result := mod(result, modulus) }
default { result := 1 }
}
// END-YUL
result0 := compute_truncated_lagrange_basis_sum(length0, add(point0, 32), numVars0, modulus0)
}
}

uint256 private constant TEST_MODULUS = 10007;

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith0Variables() public pure {
bytes memory point = hex"";
assert(computeTruncatedLagrangeBasisSum(1, point, 0, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(0, point, 0, TEST_MODULUS) == 0);
}

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith1Variables() public pure {
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002";
assert(computeTruncatedLagrangeBasisSum(2, point, 1, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(1, point, 1, TEST_MODULUS) == TEST_MODULUS - 1);
assert(computeTruncatedLagrangeBasisSum(0, point, 1, TEST_MODULUS) == 0);
}

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith2Variables() public pure {
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002"
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005";
assert(computeTruncatedLagrangeBasisSum(4, point, 2, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(3, point, 2, TEST_MODULUS) == TEST_MODULUS - 9);
assert(computeTruncatedLagrangeBasisSum(2, point, 2, TEST_MODULUS) == TEST_MODULUS - 4);
assert(computeTruncatedLagrangeBasisSum(1, point, 2, TEST_MODULUS) == 4);
assert(computeTruncatedLagrangeBasisSum(0, point, 2, TEST_MODULUS) == 0);
}

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith3Variables() public pure {
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002"
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005"
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000007";
assert(computeTruncatedLagrangeBasisSum(8, point, 3, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(7, point, 3, TEST_MODULUS) == TEST_MODULUS - 69);
assert(computeTruncatedLagrangeBasisSum(6, point, 3, TEST_MODULUS) == TEST_MODULUS - 34);
assert(computeTruncatedLagrangeBasisSum(5, point, 3, TEST_MODULUS) == 22);
assert(computeTruncatedLagrangeBasisSum(4, point, 3, TEST_MODULUS) == TEST_MODULUS - 6);
assert(computeTruncatedLagrangeBasisSum(3, point, 3, TEST_MODULUS) == 54);
assert(computeTruncatedLagrangeBasisSum(2, point, 3, TEST_MODULUS) == 24);
assert(computeTruncatedLagrangeBasisSum(1, point, 3, TEST_MODULUS) == TEST_MODULUS - 24);
assert(computeTruncatedLagrangeBasisSum(0, point, 3, TEST_MODULUS) == 0);
}
}
23 changes: 23 additions & 0 deletions crates/proof-of-sql/sol_src/tests/TestYulImport.t.psol
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;

library TestScript {
function testWeCanImportYulFromAnotherFile() public pure {
bytes memory point0 = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000"
hex"0000000000000002" hex"0000000000000000" hex"0000000000000000" hex"0000000000000000"
hex"0000000000000005";
uint256 length0 = 1;
uint256 numVars0 = 2;
uint256 modulus0 = 10007;
uint256 result0;
// solhint-disable-next-line no-inline-assembly
assembly {
// IMPORT-YUL ../base/LagrangeBasisEvaluation.sol:compute_truncated_lagrange_basis_sum
// solhint-disable-next-line no-empty-blocks
function compute_truncated_lagrange_basis_sum(length, point, num_vars, modulus) -> result {}
// END-IMPORT-YUL
result0 := compute_truncated_lagrange_basis_sum(length0, add(point0, 32), numVars0, modulus0)
}
assert(result0 == 4);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,47 +74,33 @@ where

/// Given the point `point` (or `a`) with length nu, we can evaluate the lagrange basis of length 2^nu at that point.
/// This is what [`super::compute_evaluation_vector`] does.
///
/// NOTE: if length is greater than 2^nu, this function will pad `point` with 0s, which
/// will result in padding the basis with 0s.
///
/// Call the resulting evaluation vector A. This function computes `sum A[i] for i in 0..length`. That is:
/// ```text
/// (1-a[0])(1-a[1])...(1-a[nu-1]) +
/// (a[0])(1-a[1])...(1-a[nu-1]) +
/// (1-a[0])(a[1])...(1-a[nu-1]) +
/// (a[0])(a[1])...(1-a[nu-1]) + ...
/// ```
/// # Panics
/// Panics if:
/// - The length is greater than `1` when `point` is empty.
/// - The length is greater than the maximum allowed for the given number of points, which is `2^(nu - 1)`
/// where `nu` is the number of elements in `point`.
pub fn compute_truncated_lagrange_basis_sum<F>(length: usize, point: &[F]) -> F
where
F: One + Zero + Mul<Output = F> + Add<Output = F> + Sub<Output = F> + Copy,
F: One + Zero + Mul<Output = F> + Sub<Output = F> + Copy,
{
let nu = point.len();
if nu == 0 {
assert!(length <= 1);
if length == 1 {
F::one()
} else {
F::zero()
}
if length >= 1 << point.len() {
F::one()
} else {
// Note: this is essentially the same as the inner production version.
// The only different is that the full sum is always 1, regardless of any inputs.

let first_half_term = F::one() - point[nu - 1];
let second_half_term = point[nu - 1];
let half_full_length = 1 << (nu - 1);
let sub_part_length = if length >= half_full_length {
length - half_full_length
} else {
length
};
let sub_part = compute_truncated_lagrange_basis_sum(sub_part_length, &point[..nu - 1]);
if length >= half_full_length {
first_half_term + sub_part * second_half_term
} else {
sub_part * first_half_term
}
point
.iter()
.enumerate()
.fold(F::zero(), |chi, (i, &alpha)| {
if (length >> i) & 1 == 0 {
chi * (F::one() - alpha)
} else {
F::one() - (F::one() - chi) * alpha
}
})
}
}
171 changes: 171 additions & 0 deletions crates/proof-of-sql/utils/yul-preprocessor/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
//! This binary applies a preprocessing step to Solidity files that allows for importing Yul code from other files.
use clap::Parser;
use snafu::Snafu;
use std::{
fs::{self, File},
io::{self, BufRead, BufWriter, Write},
path::Path,
};

const IMPORT_YUL: &str = "// IMPORT-YUL";
const END_IMPORT_YUL: &str = "// END-IMPORT-YUL";
const START_YUL: &str = "// START-YUL";
const END_YUL: &str = "// END-YUL";
const IMPORTED_YUL: &str = "// IMPORTED-YUL";
const END_IMPORTED_YUL: &str = "// END-IMPORTED-YUL";

#[derive(Debug, Snafu)]
enum Error {
#[snafu(transparent)]
Io { source: io::Error },
#[snafu(display("Ill-formed IMPORT-YUL statement at line {line}"))]
IllFormedImportYul { line: usize },
#[snafu(display("Unmatched END-IMPORT-YUL at line {line}"))]
UnmatchedEndImportYul { line: usize },
#[snafu(display("Unmatched IMPORT-YUL at line"))]
UnmatchedImportYul,
#[snafu(display("Function {function_name} not found in file {file_path}"))]
FunctionNotFound {
function_name: String,
file_path: String,
},
}

/// A preprocessor for Solidity files to import Yul code
///
/// This tool processes a given file or directory, replacing the import statements with the corresponding Yul code.
///
/// # Usage
///
/// The Yul code should be wrapped in `// START-YUL <function_name>` and `// END-YUL` comments in the source files.
/// The import statement should be in the form `// IMPORT-YUL <file_path>:<function_name>` in the target Solidity files.
///
/// # Example
///
/// Given a Solidity file `example.psol` with the following content:
///
/// ```solidity
/// // IMPORT-YUL yul_code.sol:my_function
/// // END-IMPORT-YUL
/// ```
///
/// And a Yul file `yul_code.sol` with the following content:
///
/// ```solidity
/// // START-YUL my_function
/// function my_function() -> result {
/// // Yul code here
/// }
/// // END-YUL
/// ```
///
/// Running the binary will produce an output file `example.p.sol` with the following content:
///
/// ```solidity
/// // IMPORTED-YUL yul_code.sol:my_function
/// function my_function() -> result {
/// // Yul code here
/// }
/// // END-IMPORTED-YUL
/// ```
#[derive(Parser, Debug)]
#[command(about, long_about)]
struct Args {
/// The path to the file or directory to process
path: String,
}

fn main() -> Result<(), Error> {
let args = Args::parse();
process_path(Path::new(&args.path))?;
Ok(())
}

fn process_path(path: &Path) -> Result<(), Error> {
if path.is_dir() {
for entry in fs::read_dir(path)? {
process_path(&entry?.path())?;
}
} else if path.extension().and_then(|ext| ext.to_str()) == Some("psol") {
process_file(path)?;
}
Ok(())
}

fn process_file(path: &Path) -> Result<(), Error> {
let file = File::open(path)?;
let reader = io::BufReader::new(file);
let mut output_lines = Vec::new();
let mut inside_import = false;
let mut import_file = String::new();
let mut function_name = String::new();
let base_path = path.parent().unwrap_or_else(|| Path::new(""));

for (line_number, line) in reader.lines().enumerate() {
let line = line?;
if let Some(import_pos) = line.find(IMPORT_YUL) {
let parts: Vec<&str> = line[import_pos + IMPORT_YUL.len()..].split(':').collect();
if parts.len() != 2 {
return Err(Error::IllFormedImportYul {
line: line_number + 1,
});
}
inside_import = true;
import_file = parts[0].trim().to_string();
function_name = parts[1].trim().to_string();
} else if line.contains(END_IMPORT_YUL) {
if !inside_import {
return Err(Error::UnmatchedEndImportYul {
line: line_number + 1,
});
}
let function_lines = extract_function(&base_path.join(&import_file), &function_name)?;
output_lines.push(format!("{IMPORTED_YUL} {import_file}:{function_name}"));
output_lines.extend(function_lines);
output_lines.push(END_IMPORTED_YUL.to_string());
inside_import = false;
} else if !inside_import {
output_lines.push(line);
}
}

if inside_import {
return Err(Error::UnmatchedImportYul);
}

let file = File::create(path.with_extension("p.sol"))?;
let mut writer = BufWriter::new(file);
for line in output_lines {
writeln!(writer, "{line}")?;
}

Ok(())
}

fn extract_function(file_path: &Path, function_name: &str) -> Result<Vec<String>, Error> {
let file = File::open(file_path)?;
let reader = io::BufReader::new(file);
let mut function_lines = Vec::new();
let mut inside_function = false;

for line in reader.lines() {
let line = line?;
if line.contains(&format!("{START_YUL} {function_name}")) {
inside_function = true;
} else if line.contains(END_YUL) {
break;
} else if inside_function {
function_lines.push(line);
}
}

if !inside_function {
return Err(Error::FunctionNotFound {
function_name: function_name.to_string(),
file_path: file_path.display().to_string(),
});
}

Ok(function_lines)
}

0 comments on commit 792aaf3

Please sign in to comment.