Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sha256): Perform compression per block and utilize ROM instead of RAM when setting up the message block #5760

Merged
merged 19 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 172 additions & 46 deletions noir_stdlib/src/hash/sha256.nr
Original file line number Diff line number Diff line change
Expand Up @@ -17,82 +17,207 @@ pub fn digest<let N: u32>(msg: [u8; N]) -> [u8; 32] {
sha256_var(msg, N as u64)
}

// Convert 64-byte array to array of 16 u32s
fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
let mut msg32: [u32; 16] = [0; 16];

for i in 0..16 {
let mut msg_field: Field = 0;
for j in 0..4 {
msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field;
}
msg32[15 - i] = msg_field as u32;
}

msg32
}

unconstrained fn build_msg_block_iter<let N: u32>(
msg: [u8; N],
message_size: u64,
mut msg_block: [u8; 64]
) -> ([u8; 64], u64) {
let mut msg_byte_ptr: u64 = 0; // Message byte pointer
for k in 0..N {
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
if k as u64 < message_size {
msg_block[msg_byte_ptr] = msg[k];
msg_byte_ptr = msg_byte_ptr + 1;

if msg_byte_ptr == 64 {
msg_byte_ptr = 0;
}
}
}
(msg_block, msg_byte_ptr)
}

// Verify the block we are compressing was appropriately constructed
fn verify_msg_block<let N: u32>(msg: [u8; N], message_size: u64, msg_block: [u8; 64]) -> u64 {
let mut msg_byte_ptr: u64 = 0; // Message byte pointer
for k in 0..N {
if k as u64 < message_size {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
msg_byte_ptr = msg_byte_ptr + 1;
if msg_byte_ptr == 64 {
// Enough to hash block
msg_byte_ptr = 0;
}
} else {
// Need to assert over the msg block in the else case as well
if N < 64 {
assert_eq(msg_block[msg_byte_ptr], 0);
} else {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
}
}
}
msg_byte_ptr
}

// Variable size SHA-256 hash
pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
let num_blocks = N / 64;
let mut msg_block: [u8; 64] = [0; 64];
let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value
let mut i: u64 = 0; // Message byte pointer
for k in 0..N {
if k as u64 < message_size {
// Populate msg_block
msg_block[i] = msg[k];
i = i + 1;
if i == 64 {
// Enough to hash block
h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h);
let mut msg_byte_ptr = 0;

i = 0;
}
if num_blocks == 0 {
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
unsafe {
let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block);
msg_block = new_msg_block;
msg_byte_ptr = new_msg_byte_ptr;
}

if !crate::runtime::is_unconstrained() {
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block);
}
}

for _ in 0..num_blocks {
unsafe {
let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block);
msg_block = new_msg_block;
msg_byte_ptr = new_msg_byte_ptr;
}
if !crate::runtime::is_unconstrained() {
// Verify the block we are compressing was appropriately constructed
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block);
}

// Hash the block
h = sha256_compression(msg_u8_to_u32(msg_block), h);
}

let last_block = msg_block;
// Pad the rest such that we have a [u32; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[i] = 1 << 7;
i = i + 1;
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[msg_byte_ptr] = 1 << 7;
msg_byte_ptr = msg_byte_ptr + 1;
unsafe {
let (new_msg_block, new_msg_byte_ptr)= pad_msg_block(msg_block, msg_byte_ptr);
msg_block = new_msg_block;
if crate::runtime::is_unconstrained() {
msg_byte_ptr = new_msg_byte_ptr;
}
}

if !crate::runtime::is_unconstrained() {
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
// Not enough bits (64) to store length. Fill up with zeros.
for _i in 57..64 {
if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 {
assert_eq(msg_block[msg_byte_ptr], 0);
msg_byte_ptr += 1;
}
}
}

if msg_byte_ptr >= 57 {
h = sha256_compression(msg_u8_to_u32(msg_block), h);

msg_byte_ptr = 0;
}

unsafe {
msg_block = attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size);
}

if !crate::runtime::is_unconstrained() {
for i in 0..64 {
if i as u64 < msg_byte_ptr - 1 {
assert_eq(msg_block[i], last_block[i]);
}
}
assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7);

let len = 8 * message_size;
let len_bytes = (len as Field).to_le_bytes(8);
// In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
for _ in 0..64 {
if msg_byte_ptr < 56 {
assert_eq(msg_block[msg_byte_ptr], 0);
msg_byte_ptr = msg_byte_ptr + 1;
}
}

let mut block_idx = 0;
for i in 56..64 {
assert_eq(msg_block[63 - block_idx], len_bytes[i - 56]);
block_idx = block_idx + 1;
}
}

hash_final_block(msg_block, h)
}

unconstrained fn pad_msg_block<let N: u32>(
mut msg_block: [u8; 64],
mut msg_byte_ptr: u64
) -> ([u8; 64], u64) {
// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
if i >= 57 {
if msg_byte_ptr >= 57 {
// Not enough bits (64) to store length. Fill up with zeros.
if i < 64 {
for _i in 57..64 {
if i <= 63 {
msg_block[i] = 0;
i += 1;
if msg_byte_ptr < 64 {
for _ in 57..64 {
if msg_byte_ptr <= 63 {
msg_block[msg_byte_ptr] = 0;
msg_byte_ptr += 1;
}
}
}
h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h);

i = 0;
}
(msg_block, msg_byte_ptr)
}

unconstrained fn attach_len_to_msg_block<let N: u32>(
mut msg_block: [u8; 64],
mut msg_byte_ptr: u64,
message_size: u64
) -> [u8; 64] {
let len = 8 * message_size;
let len_bytes = (len as Field).to_le_bytes(8);
for _i in 0..64 {
// In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56).
if i < 56 {
msg_block[i] = 0;
i = i + 1;
} else if i < 64 {
// In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
if msg_byte_ptr < 56 {
msg_block[msg_byte_ptr] = 0;
msg_byte_ptr = msg_byte_ptr + 1;
} else if msg_byte_ptr < 64 {
for j in 0..8 {
msg_block[63 - j] = len_bytes[j];
}
i += 8;
}
}
hash_final_block(msg_block, h)
}

// Convert 64-byte array to array of 16 u32s
fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
let mut msg32: [u32; 16] = [0; 16];

for i in 0..16 {
let mut msg_field: Field = 0;
for j in 0..4 {
msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field;
msg_byte_ptr += 8;
}
msg32[15 - i] = msg_field as u32;
}

msg32
msg_block
}

fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] {
let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes

// Hash final padded block
state = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), state);
state = sha256_compression(msg_u8_to_u32(msg_block), state);

// Return final hash as byte array
for j in 0..8 {
Expand All @@ -104,3 +229,4 @@ fn hash_final_block(msg_block: [u8; 64], mut state: [u32; 8]) -> [u8; 32] {

out_h
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "sha256_var_size_regression"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
enable = [true, false]
foo = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
toggle = false
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
global NUM_HASHES = 2;

fn main(foo: [u8; 95], toggle: bool, enable: [bool; NUM_HASHES]) {
let mut result = [[0; 32]; NUM_HASHES];
let mut const_result = [[0; 32]; NUM_HASHES];
let size: Field = 93 + toggle as Field * 2;
for i in 0..NUM_HASHES {
if enable[i] {
result[i] = std::sha256::sha256_var(foo, size as u64);
const_result[i] = std::sha256::sha256_var(foo, 93);
}
}

for i in 0..NUM_HASHES {
assert_eq(result[i], const_result[i]);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "sha256_var_witness_const_regression"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
input = [0, 0]
toggle = false
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
fn main(input: [u8; 2], toggle: bool) {
let size: Field = 1 + toggle as Field;
assert(!toggle);

let variable_sha = std::sha256::sha256_var(input, size as u64);
let constant_sha = std::sha256::sha256_var(input, 1);

assert_eq(variable_sha, constant_sha);
}
Loading