Skip to content

Commit

Permalink
fix(sha256): Fix upper bound when building msg block and delay final …
Browse files Browse the repository at this point in the history
…block compression under certain cases (noir-lang#5838)

# Description

## Problem\*

Resolves noir-lang#5836 

## Summary\*

We accept a start index based upon the current block when parsing a
message. We should accurately base the upper bound to be based upon this
start index.

We also have special handling for building a message block but not
compressing it when the message is less than the block size. We need to
also do this handling for the last message block when we have a message
that is larger than the block size.

## Additional Context

~~sha256_var is currently getting warnings from the under constrained
check. It looks to only be happening on the new regression test added as
part of this PR that uses a larger message. The old sha256 tests do not
look to trigger these warnings which is strange. I am a bit unsure why I
am getting these warnings as msg block and msg block pointer are being
verified on each iteration of the loop.~~

<img width="965" alt="Screenshot 2024-08-27 at 12 04 58 PM"
src="https://github.com/user-attachments/assets/c71bff9d-bfea-4765-a134-91eca92c7806">

EDIT: This was only happening as my test was hashing constant values,
thus it was a dumb circuit. e.g the following:
```rust
fn main(result: pub [u8; 32]) {
    let headers = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116];
    let hash = std::hash::sha256_var(headers, headers.len() as u64);
    assert_eq(hash, result);
}
```
The message needs to come from the inputs and the under-constrained
warnings go away.

## Documentation\*

Check one:
- [X] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
vezenovm authored Aug 28, 2024
1 parent 5739904 commit 130b7b6
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,6 @@ impl<F: AcirField> AcirContext<F> {
name,
BlackBoxFunc::MultiScalarMul
| BlackBoxFunc::Keccakf1600
| BlackBoxFunc::Sha256Compression
| BlackBoxFunc::Blake2s
| BlackBoxFunc::Blake3
| BlackBoxFunc::AND
Expand Down
150 changes: 76 additions & 74 deletions noir_stdlib/src/hash/sha256.nr
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::runtime::is_unconstrained;

// Implementation of SHA-256 mapping a byte array of variable length to
// 32 bytes.

Expand Down Expand Up @@ -32,21 +34,17 @@ fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
msg32
}

unconstrained fn build_msg_block_iter<let N: u32>(
msg: [u8; N],
message_size: u64,
mut msg_block: [u8; 64],
msg_start: u32
) -> ([u8; 64], u64) {
unconstrained fn build_msg_block_iter<let N: u32>(msg: [u8; N], message_size: u64, msg_start: u32) -> ([u8; 64], u64) {
let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE];
let mut msg_byte_ptr: u64 = 0; // Message byte pointer
for k in msg_start..N {
let mut msg_end = msg_start + BLOCK_SIZE;
if msg_end > N {
msg_end = N;
}
for k in msg_start..msg_end {
if k as u64 < message_size {
msg_block[msg_byte_ptr] = msg[k];
msg_byte_ptr = msg_byte_ptr + 1;

if msg_byte_ptr == 64 {
msg_byte_ptr = 0;
}
}
}
(msg_block, msg_byte_ptr)
Expand All @@ -60,27 +58,32 @@ fn verify_msg_block<let N: u32>(
msg_start: u32
) -> u64 {
let mut msg_byte_ptr: u64 = 0; // Message byte pointer
for k in msg_start..N {
let mut msg_end = msg_start + BLOCK_SIZE;
let mut extra_bytes = 0;
if msg_end > N {
msg_end = N;
extra_bytes = msg_end - N;
}

for k in msg_start..msg_end {
if k as u64 < message_size {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
msg_byte_ptr = msg_byte_ptr + 1;
if msg_byte_ptr == 64 {
// Enough to hash block
msg_byte_ptr = 0;
}
}
}

for i in 0..BLOCK_SIZE {
if i as u64 >= msg_byte_ptr {
assert_eq(msg_block[i], 0);
} else {
// Need to assert over the msg block in the else case as well
if N < 64 {
assert_eq(msg_block[msg_byte_ptr], 0);
} else {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
}
assert_eq(msg_block[i], msg[msg_start + i - extra_bytes]);
}
}

msg_byte_ptr
}

global BLOCK_SIZE = 64;
global ZERO = 0;

// Variable size SHA-256 hash
pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
Expand All @@ -89,38 +92,55 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value
let mut msg_byte_ptr = 0; // Pointer into msg_block

if num_blocks == 0 {
unsafe {
let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, 0);
msg_block = new_msg_block;
for i in 0..num_blocks {
let (new_msg_block, new_msg_byte_ptr) = unsafe {
build_msg_block_iter(msg, message_size, BLOCK_SIZE * i)
};
msg_block = new_msg_block;

if !is_unconstrained() {
// Verify the block we are compressing was appropriately constructed
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i);
} else {
msg_byte_ptr = new_msg_byte_ptr;
}

if !crate::runtime::is_unconstrained() {
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, 0);
}
// Compress the block
h = sha256_compression(msg_u8_to_u32(msg_block), h);
}

for i in 0..num_blocks {
unsafe {
let (new_msg_block, new_msg_byte_ptr) = build_msg_block_iter(msg, message_size, msg_block, BLOCK_SIZE * i);
msg_block = new_msg_block;
let modulo = N % BLOCK_SIZE;
// Handle setup of the final msg block.
// This case is only hit if the msg is less than the block size,
// or our message cannot be evenly split into blocks.
if modulo != 0 {
let (new_msg_block, new_msg_byte_ptr) = unsafe {
build_msg_block_iter(msg, message_size, BLOCK_SIZE * num_blocks)
};
msg_block = new_msg_block;

if !is_unconstrained() {
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * num_blocks);
} else {
msg_byte_ptr = new_msg_byte_ptr;
}
if !crate::runtime::is_unconstrained() {
// Verify the block we are compressing was appropriately constructed
msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, BLOCK_SIZE * i);
}
}

// Hash the block
h = sha256_compression(msg_u8_to_u32(msg_block), h);
if msg_byte_ptr == BLOCK_SIZE as u64 {
msg_byte_ptr = 0;
}

let last_block = msg_block;
// This variable is used to get around the compiler under-constrained check giving a warning.
// We want to check against a constant zero, but if it does not come from the circuit inputs
// or return values the compiler check will issue a warning.
let zero = msg_block[0] - msg_block[0];

// Pad the rest such that we have a [u32; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
msg_block[msg_byte_ptr] = 1 << 7;
let last_block = msg_block;
msg_byte_ptr = msg_byte_ptr + 1;

unsafe {
let (new_msg_block, new_msg_byte_ptr) = pad_msg_block(msg_block, msg_byte_ptr);
msg_block = new_msg_block;
Expand All @@ -131,18 +151,15 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {

if !crate::runtime::is_unconstrained() {
for i in 0..64 {
if i as u64 < msg_byte_ptr - 1 {
assert_eq(msg_block[i], last_block[i]);
}
assert_eq(msg_block[i], last_block[i]);
}
assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7);

// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
// Not enough bits (64) to store length. Fill up with zeros.
for _i in 57..64 {
if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 {
assert_eq(msg_block[msg_byte_ptr], 0);
assert_eq(msg_block[msg_byte_ptr], zero);
msg_byte_ptr += 1;
}
}
Expand All @@ -154,34 +171,23 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
msg_byte_ptr = 0;
}

unsafe {
msg_block = attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size);
}
msg_block = unsafe {
attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size)
};

if !crate::runtime::is_unconstrained() {
if msg_byte_ptr != 0 {
for i in 0..64 {
if i as u64 < msg_byte_ptr - 1 {
assert_eq(msg_block[i], last_block[i]);
}
for i in 0..56 {
if i < msg_byte_ptr {
assert_eq(msg_block[i], last_block[i]);
} else {
assert_eq(msg_block[i], zero);
}
assert_eq(msg_block[msg_byte_ptr - 1], 1 << 7);
}

let len = 8 * message_size;
let len_bytes = (len as Field).to_le_bytes(8);
// In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
for _ in 0..64 {
if msg_byte_ptr < 56 {
assert_eq(msg_block[msg_byte_ptr], 0);
msg_byte_ptr = msg_byte_ptr + 1;
}
}

let mut block_idx = 0;
let len_bytes = (len as Field).to_be_bytes(8);
for i in 56..64 {
assert_eq(msg_block[63 - block_idx], len_bytes[i - 56]);
block_idx = block_idx + 1;
assert_eq(msg_block[i], len_bytes[i - 56]);
}
}

Expand All @@ -205,21 +211,17 @@ unconstrained fn pad_msg_block(mut msg_block: [u8; 64], mut msg_byte_ptr: u64) -
(msg_block, msg_byte_ptr)
}

unconstrained fn attach_len_to_msg_block(
mut msg_block: [u8; 64],
mut msg_byte_ptr: u64,
message_size: u64
) -> [u8; 64] {
unconstrained fn attach_len_to_msg_block(mut msg_block: [u8; 64], mut msg_byte_ptr: u64, message_size: u64) -> [u8; 64] {
let len = 8 * message_size;
let len_bytes = (len as Field).to_le_bytes(8);
let len_bytes = (len as Field).to_be_bytes(8);
for _i in 0..64 {
// In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
if msg_byte_ptr < 56 {
msg_block[msg_byte_ptr] = 0;
msg_byte_ptr = msg_byte_ptr + 1;
} else if msg_byte_ptr < 64 {
for j in 0..8 {
msg_block[63 - j] = len_bytes[j];
msg_block[msg_byte_ptr + j] = len_bytes[j];
}
msg_byte_ptr += 8;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = "sha256_var_witness_const_regression"
name = "sha256_regression"
type = "bin"
authors = [""]
compiler_version = ">=0.33.0"
Expand Down
9 changes: 9 additions & 0 deletions test_programs/execution_success/sha256_regression/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
msg_just_over_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116]
msg_multiple_of_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99]
msg_just_under_block = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59]
msg_big_not_block_multiple = [102, 114, 111, 109, 58, 114, 117, 110, 110, 105, 101, 114, 46, 108, 101, 97, 103, 117, 101, 115, 46, 48, 106, 64, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 58, 116, 101, 120, 116, 47, 112, 108, 97, 105, 110, 59, 32, 99, 104, 97, 114, 115, 101, 116, 61, 117, 115, 45, 97, 115, 99, 105, 105, 13, 10, 109, 105, 109, 101, 45, 118, 101, 114, 115, 105, 111, 110, 58, 49, 46, 48, 32, 40, 77, 97, 99, 32, 79, 83, 32, 88, 32, 77, 97, 105, 108, 32, 49, 54, 46, 48, 32, 92, 40, 51, 55, 51, 49, 46, 53, 48, 48, 46, 50, 51, 49, 92, 41, 41, 13, 10, 115, 117, 98, 106, 101, 99, 116, 58, 72, 101, 108, 108, 111, 13, 10, 109, 101, 115, 115, 97, 103, 101, 45, 105, 100, 58, 60, 56, 70, 56, 49, 57, 68, 51, 50, 45, 66, 54, 65, 67, 45, 52, 56, 57, 68, 45, 57, 55, 55, 70, 45, 52, 51, 56, 66, 66, 67, 52, 67, 65, 66, 50, 55, 64, 109, 101, 46, 99, 111, 109, 62, 13, 10, 100, 97, 116, 101, 58, 83, 97, 116, 44, 32, 50, 54, 32, 65, 117, 103, 32, 50, 48, 50, 51, 32, 49, 50, 58, 50, 53, 58, 50, 50, 32, 43, 48, 52, 48, 48, 13, 10, 116, 111, 58, 122, 107, 101, 119, 116, 101, 115, 116, 64, 103, 109, 97, 105, 108, 46, 99, 111, 109, 13, 10, 100, 107, 105, 109, 45, 115, 105, 103, 110, 97, 116, 117, 114, 101, 58, 118, 61, 49, 59, 32, 97, 61, 114, 115, 97, 45, 115, 104, 97, 50, 53, 54, 59, 32, 99, 61, 114, 101, 108, 97, 120, 101, 100, 47, 114, 101, 108, 97, 120, 101, 100, 59, 32, 100, 61, 105, 99, 108, 111, 117, 100, 46, 99, 111, 109, 59, 32, 115, 61, 49, 97, 49, 104, 97, 105, 59, 32, 116, 61, 49, 54, 57, 51, 48, 51, 56, 51, 51, 55, 59, 32, 98, 104, 61, 55, 120, 81, 77, 68, 117, 111, 86, 86, 85, 52, 109, 48, 87, 48, 87, 82, 86, 83, 114, 86, 88, 77, 101, 71, 83, 73, 65, 83, 115, 110, 117, 99, 75, 57, 100, 74, 115, 114, 99, 43, 118, 85, 61, 59, 32, 104, 61, 102, 114, 111, 109, 58, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 77, 105, 109, 101, 45, 86, 101, 114, 115, 105, 111, 110, 58, 83, 117, 98, 106, 101, 99, 116, 58, 77, 101, 115, 115, 97, 103, 101, 45, 73, 100, 58, 68, 97, 116, 101, 58, 116, 111, 59, 32, 98, 61]
# Results matched against ethers library
result_just_over_block = [91, 122, 146, 93, 52, 109, 133, 148, 171, 61, 156, 70, 189, 238, 153, 7, 222, 184, 94, 24, 65, 114, 192, 244, 207, 199, 87, 232, 192, 224, 171, 207]
result_multiple_of_block = [116, 90, 151, 31, 78, 22, 138, 180, 211, 189, 69, 76, 227, 200, 155, 29, 59, 123, 154, 60, 47, 153, 203, 129, 157, 251, 48, 2, 79, 11, 65, 47]
result_just_under_block = [143, 140, 76, 173, 222, 123, 102, 68, 70, 149, 207, 43, 39, 61, 34, 79, 216, 252, 213, 165, 74, 16, 110, 74, 29, 64, 138, 167, 30, 1, 9, 119]
result_big = [112, 144, 73, 182, 208, 98, 9, 238, 54, 229, 61, 145, 222, 17, 72, 62, 148, 222, 186, 55, 192, 82, 220, 35, 66, 47, 193, 200, 22, 38, 26, 186]
26 changes: 26 additions & 0 deletions test_programs/execution_success/sha256_regression/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
fn main(
msg_just_over_block: [u8; 68],
result_just_over_block: pub [u8; 32],
msg_multiple_of_block: [u8; 448],
result_multiple_of_block: pub [u8; 32],
// We want to make sure we are testing a message with a size >= 57 but < 64
msg_just_under_block: [u8; 60],
result_just_under_block: pub [u8; 32],
msg_big_not_block_multiple: [u8; 472],
result_big: pub [u8; 32]
) {
let hash = std::hash::sha256_var(msg_just_over_block, msg_just_over_block.len() as u64);
assert_eq(hash, result_just_over_block);

let hash = std::hash::sha256_var(msg_multiple_of_block, msg_multiple_of_block.len() as u64);
assert_eq(hash, result_multiple_of_block);

let hash = std::hash::sha256_var(msg_just_under_block, msg_just_under_block.len() as u64);
assert_eq(hash, result_just_under_block);

let hash = std::hash::sha256_var(
msg_big_not_block_multiple,
msg_big_not_block_multiple.len() as u64
);
assert_eq(hash, result_big);
}

This file was deleted.

This file was deleted.

0 comments on commit 130b7b6

Please sign in to comment.