Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: yul lagrange basis evaluation #442

Merged
merged 3 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/lint-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ jobs:
with:
version: nightly
- name: Run tests
run: forge test --root=crates/proof-of-sql --summary --detailed
run: |
cargo run --bin yul_preprocessor crates/proof-of-sql
forge test --root=crates/proof-of-sql --summary --detailed
solhint:
name: solhint
Expand All @@ -245,4 +247,4 @@ jobs:
- name: Install solhint
run: npm install -g solhint
- name: Run tests
run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' -w 0
run: solhint -c 'crates/proof-of-sql/.solhint.json' 'crates/proof-of-sql/**/*.sol' 'crates/proof-of-sql/**/*.psol' -w 0
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,5 @@ cache

# any output files from generating public params
output/

*.p.sol
5 changes: 5 additions & 0 deletions crates/proof-of-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ name = "commitment-utility"
path = "utils/commitment-utility/main.rs"
required-features = [ "std", "blitzar"]

[[bin]]
name = "yul_preprocessor"
path = "utils/yul-preprocessor/main.rs"
required-features = [ "std" ]

[[example]]
name = "hello_world"
required-features = ["test"]
Expand Down
76 changes: 76 additions & 0 deletions crates/proof-of-sql/sol_src/base/LagrangeBasisEvaluation.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;

contract LagrangeBasisEvaluation {
function computeTruncatedLagrangeBasisSum(uint256 length0, bytes memory point0, uint256 numVars0, uint256 modulus0)
public
pure
returns (uint256 result0)
{
// solhint-disable-next-line no-inline-assembly
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps you want to add this to your solhint: "no-inline-assembly": "off", to avoid repeatedly using // solhint-disable-next-line rule

assembly {
// START-YUL compute_truncated_lagrange_basis_sum
function compute_truncated_lagrange_basis_sum(length, point, num_vars, modulus) -> result {
let ONE := add(modulus, 1)
// result := 0 // implicitly set by the EVM

// Invariant that holds within the for loop:
// 0 <= result <= modulus + 1
// This invariant reduces modulus operations.
for {} num_vars {} {
switch and(length, 1)
case 0 { result := mulmod(result, sub(ONE, mod(mload(point), modulus)), modulus) }
default { result := sub(ONE, mulmod(sub(ONE, result), mload(point), modulus)) }
num_vars := sub(num_vars, 1)
length := shr(1, length)
point := add(point, 32)
}
switch length
case 0 { result := mod(result, modulus) }
default { result := 1 }
}
// END-YUL
result0 := compute_truncated_lagrange_basis_sum(length0, add(point0, 32), numVars0, modulus0)
}
}

uint256 private constant TEST_MODULUS = 10007;

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith0Variables() public pure {
bytes memory point = hex"";
assert(computeTruncatedLagrangeBasisSum(1, point, 0, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(0, point, 0, TEST_MODULUS) == 0);
}

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith1Variables() public pure {
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002";
assert(computeTruncatedLagrangeBasisSum(2, point, 1, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(1, point, 1, TEST_MODULUS) == TEST_MODULUS - 1);
assert(computeTruncatedLagrangeBasisSum(0, point, 1, TEST_MODULUS) == 0);
}

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith2Variables() public pure {
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002"
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005";
assert(computeTruncatedLagrangeBasisSum(4, point, 2, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(3, point, 2, TEST_MODULUS) == TEST_MODULUS - 9);
assert(computeTruncatedLagrangeBasisSum(2, point, 2, TEST_MODULUS) == TEST_MODULUS - 4);
assert(computeTruncatedLagrangeBasisSum(1, point, 2, TEST_MODULUS) == 4);
assert(computeTruncatedLagrangeBasisSum(0, point, 2, TEST_MODULUS) == 0);
}

function testComputeTruncatedLagrangeBasisSumGivesCorrectValuesWith3Variables() public pure {
bytes memory point = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000002"
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000005"
hex"0000000000000000" hex"0000000000000000" hex"0000000000000000" hex"0000000000000007";
assert(computeTruncatedLagrangeBasisSum(8, point, 3, TEST_MODULUS) == 1);
assert(computeTruncatedLagrangeBasisSum(7, point, 3, TEST_MODULUS) == TEST_MODULUS - 69);
assert(computeTruncatedLagrangeBasisSum(6, point, 3, TEST_MODULUS) == TEST_MODULUS - 34);
assert(computeTruncatedLagrangeBasisSum(5, point, 3, TEST_MODULUS) == 22);
assert(computeTruncatedLagrangeBasisSum(4, point, 3, TEST_MODULUS) == TEST_MODULUS - 6);
assert(computeTruncatedLagrangeBasisSum(3, point, 3, TEST_MODULUS) == 54);
assert(computeTruncatedLagrangeBasisSum(2, point, 3, TEST_MODULUS) == 24);
assert(computeTruncatedLagrangeBasisSum(1, point, 3, TEST_MODULUS) == TEST_MODULUS - 24);
assert(computeTruncatedLagrangeBasisSum(0, point, 3, TEST_MODULUS) == 0);
}
}
23 changes: 23 additions & 0 deletions crates/proof-of-sql/sol_src/tests/TestYulImport.t.psol
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;

library TestScript {
function testWeCanImportYulFromAnotherFile() public pure {
bytes memory point0 = hex"0000000000000000" hex"0000000000000000" hex"0000000000000000"
hex"0000000000000002" hex"0000000000000000" hex"0000000000000000" hex"0000000000000000"
hex"0000000000000005";
uint256 length0 = 1;
uint256 numVars0 = 2;
uint256 modulus0 = 10007;
uint256 result0;
// solhint-disable-next-line no-inline-assembly
assembly {
// IMPORT-YUL ../base/LagrangeBasisEvaluation.sol:compute_truncated_lagrange_basis_sum
// solhint-disable-next-line no-empty-blocks
function compute_truncated_lagrange_basis_sum(length, point, num_vars, modulus) -> result {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line isn't needed right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. It's needed for the sake of the linter, so it knows what the function signature is.

// END-IMPORT-YUL
result0 := compute_truncated_lagrange_basis_sum(length0, add(point0, 32), numVars0, modulus0)
}
assert(result0 == 4);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,47 +74,33 @@ where

/// Given the point `point` (or `a`) with length nu, we can evaluate the lagrange basis of length 2^nu at that point.
/// This is what [`super::compute_evaluation_vector`] does.
///
/// NOTE: if length is greater than 2^nu, this function will pad `point` with 0s, which
/// will result in padding the basis with 0s.
///
/// Call the resulting evaluation vector A. This function computes `sum A[i] for i in 0..length`. That is:
/// ```text
/// (1-a[0])(1-a[1])...(1-a[nu-1]) +
/// (a[0])(1-a[1])...(1-a[nu-1]) +
/// (1-a[0])(a[1])...(1-a[nu-1]) +
/// (a[0])(a[1])...(1-a[nu-1]) + ...
/// ```
/// # Panics
/// Panics if:
/// - The length is greater than `1` when `point` is empty.
/// - The length is greater than the maximum allowed for the given number of points, which is `2^(nu - 1)`
/// where `nu` is the number of elements in `point`.
pub fn compute_truncated_lagrange_basis_sum<F>(length: usize, point: &[F]) -> F
where
F: One + Zero + Mul<Output = F> + Add<Output = F> + Sub<Output = F> + Copy,
F: One + Zero + Mul<Output = F> + Sub<Output = F> + Copy,
{
let nu = point.len();
if nu == 0 {
assert!(length <= 1);
if length == 1 {
F::one()
} else {
F::zero()
}
if length >= 1 << point.len() {
F::one()
} else {
// Note: this is essentially the same as the inner production version.
// The only different is that the full sum is always 1, regardless of any inputs.

let first_half_term = F::one() - point[nu - 1];
let second_half_term = point[nu - 1];
let half_full_length = 1 << (nu - 1);
let sub_part_length = if length >= half_full_length {
length - half_full_length
} else {
length
};
let sub_part = compute_truncated_lagrange_basis_sum(sub_part_length, &point[..nu - 1]);
if length >= half_full_length {
first_half_term + sub_part * second_half_term
} else {
sub_part * first_half_term
}
point
.iter()
.enumerate()
.fold(F::zero(), |chi, (i, &alpha)| {
if (length >> i) & 1 == 0 {
chi * (F::one() - alpha)
} else {
F::one() - (F::one() - chi) * alpha
}
})
}
}
171 changes: 171 additions & 0 deletions crates/proof-of-sql/utils/yul-preprocessor/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
//! This binary applies a preprocessing step to Solidity files that allows for importing Yul code from other files.

use clap::Parser;
use snafu::Snafu;
use std::{
fs::{self, File},
io::{self, BufRead, BufWriter, Write},
path::Path,
};

const IMPORT_YUL: &str = "// IMPORT-YUL";
const END_IMPORT_YUL: &str = "// END-IMPORT-YUL";
const START_YUL: &str = "// START-YUL";
const END_YUL: &str = "// END-YUL";
const IMPORTED_YUL: &str = "// IMPORTED-YUL";
const END_IMPORTED_YUL: &str = "// END-IMPORTED-YUL";

#[derive(Debug, Snafu)]
enum Error {
#[snafu(transparent)]
Io { source: io::Error },
#[snafu(display("Ill-formed IMPORT-YUL statement at line {line}"))]
IllFormedImportYul { line: usize },
#[snafu(display("Unmatched END-IMPORT-YUL at line {line}"))]
UnmatchedEndImportYul { line: usize },
#[snafu(display("Unmatched IMPORT-YUL at line"))]
UnmatchedImportYul,
#[snafu(display("Function {function_name} not found in file {file_path}"))]
FunctionNotFound {
function_name: String,
file_path: String,
},
}

/// A preprocessor for Solidity files to import Yul code
///
/// This tool processes a given file or directory, replacing the import statements with the corresponding Yul code.
///
/// # Usage
///
/// The Yul code should be wrapped in `// START-YUL <function_name>` and `// END-YUL` comments in the source files.
/// The import statement should be in the form `// IMPORT-YUL <file_path>:<function_name>` in the target Solidity files.
///
/// # Example
///
/// Given a Solidity file `example.psol` with the following content:
///
/// ```solidity
/// // IMPORT-YUL yul_code.sol:my_function
/// // END-IMPORT-YUL
/// ```
///
/// And a Yul file `yul_code.sol` with the following content:
///
/// ```solidity
/// // START-YUL my_function
/// function my_function() -> result {
/// // Yul code here
/// }
/// // END-YUL
/// ```
///
/// Running the binary will produce an output file `example.p.sol` with the following content:
///
/// ```solidity
/// // IMPORTED-YUL yul_code.sol:my_function
/// function my_function() -> result {
/// // Yul code here
/// }
/// // END-IMPORTED-YUL
/// ```
#[derive(Parser, Debug)]
#[command(about, long_about)]
struct Args {
/// The path to the file or directory to process
path: String,
}

fn main() -> Result<(), Error> {
let args = Args::parse();
process_path(Path::new(&args.path))?;
Ok(())
}

fn process_path(path: &Path) -> Result<(), Error> {
if path.is_dir() {
for entry in fs::read_dir(path)? {
process_path(&entry?.path())?;
}
} else if path.extension().and_then(|ext| ext.to_str()) == Some("psol") {
process_file(path)?;
}
Ok(())
}

fn process_file(path: &Path) -> Result<(), Error> {
let file = File::open(path)?;
let reader = io::BufReader::new(file);
let mut output_lines = Vec::new();
let mut inside_import = false;
let mut import_file = String::new();
let mut function_name = String::new();
let base_path = path.parent().unwrap_or_else(|| Path::new(""));

for (line_number, line) in reader.lines().enumerate() {
let line = line?;
if let Some(import_pos) = line.find(IMPORT_YUL) {
let parts: Vec<&str> = line[import_pos + IMPORT_YUL.len()..].split(':').collect();
if parts.len() != 2 {
return Err(Error::IllFormedImportYul {
line: line_number + 1,
});
}
inside_import = true;
import_file = parts[0].trim().to_string();
function_name = parts[1].trim().to_string();
} else if line.contains(END_IMPORT_YUL) {
if !inside_import {
return Err(Error::UnmatchedEndImportYul {
line: line_number + 1,
});
}
let function_lines = extract_function(&base_path.join(&import_file), &function_name)?;
output_lines.push(format!("{IMPORTED_YUL} {import_file}:{function_name}"));
output_lines.extend(function_lines);
output_lines.push(END_IMPORTED_YUL.to_string());
inside_import = false;
} else if !inside_import {
output_lines.push(line);
}
}

if inside_import {
return Err(Error::UnmatchedImportYul);
}

let file = File::create(path.with_extension("p.sol"))?;
let mut writer = BufWriter::new(file);
for line in output_lines {
writeln!(writer, "{line}")?;
}

Ok(())
}

fn extract_function(file_path: &Path, function_name: &str) -> Result<Vec<String>, Error> {
let file = File::open(file_path)?;
let reader = io::BufReader::new(file);
let mut function_lines = Vec::new();
let mut inside_function = false;

for line in reader.lines() {
let line = line?;
if line.contains(&format!("{START_YUL} {function_name}")) {
inside_function = true;
} else if line.contains(END_YUL) {
break;
} else if inside_function {
function_lines.push(line);
}
}

if !inside_function {
return Err(Error::FunctionNotFound {
function_name: function_name.to_string(),
file_path: file_path.display().to_string(),
});
}

Ok(function_lines)
}
Loading