Skip to content

Commit

Permalink
feat: stream cipher zk (#384)
Browse files Browse the repository at this point in the history
* feat: stream cipher zk

* bump mpz version to ecb8c54
  • Loading branch information
sinui0 authored Nov 18, 2023
1 parent 6270dd8 commit 786195b
Show file tree
Hide file tree
Showing 16 changed files with 638 additions and 580 deletions.
4 changes: 2 additions & 2 deletions components/aead/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ tracing = [
tlsn-block-cipher = { path = "../cipher/block-cipher" }
tlsn-stream-cipher = { path = "../cipher/stream-cipher" }
tlsn-universal-hash = { path = "../universal-hash" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" }

async-trait = "0.1"
Expand Down
4 changes: 2 additions & 2 deletions components/cipher/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ resolver = "2"

[workspace.dependencies]
# tlsn
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "1ac6779" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ecb8c54" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "8d8ffe1" }

# crypto
Expand Down
4 changes: 2 additions & 2 deletions components/cipher/stream-cipher/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ mpz-circuits.workspace = true
mpz-garble.workspace = true
tlsn-utils.workspace = true
aes.workspace = true
ctr.workspace = true
cipher.workspace = true
async-trait.workspace = true
thiserror.workspace = true
derive_builder.workspace = true
tracing = { workspace = true, optional = true }

[dev-dependencies]
ctr.workspace = true
cipher.workspace = true
futures.workspace = true
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rstest = { workspace = true, features = ["async-timeout"] }
Expand Down
76 changes: 72 additions & 4 deletions components/cipher/stream-cipher/benches/mock.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use criterion::{criterion_group, criterion_main, Criterion, Throughput};

use mpz_garble::{protocol::deap::mock::create_mock_deap_vm, Memory, Vm};
use tlsn_stream_cipher::{Aes128Ctr, MpcStreamCipher, StreamCipher, StreamCipherConfigBuilder};
use tlsn_stream_cipher::{
Aes128Ctr, CtrCircuit, MpcStreamCipher, StreamCipher, StreamCipherConfigBuilder,
};

async fn bench_stream_cipher_encrypt(thread_count: usize, len: usize) {
let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await;
Expand Down Expand Up @@ -57,20 +59,86 @@ async fn bench_stream_cipher_encrypt(thread_count: usize, len: usize) {
_ = tokio::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap();
}

async fn bench_stream_cipher_zk(thread_count: usize, len: usize) {
let (mut leader_vm, mut follower_vm) = create_mock_deap_vm("test").await;

let key = [0u8; 16];
let iv = [0u8; 4];

let leader_thread = leader_vm.new_thread("key_config").await.unwrap();
let leader_key = leader_thread.new_public_input::<[u8; 16]>("key").unwrap();
let leader_iv = leader_thread.new_public_input::<[u8; 4]>("iv").unwrap();

leader_thread.assign(&leader_key, key).unwrap();
leader_thread.assign(&leader_iv, iv).unwrap();

let follower_thread = follower_vm.new_thread("key_config").await.unwrap();
let follower_key = follower_thread.new_public_input::<[u8; 16]>("key").unwrap();
let follower_iv = follower_thread.new_public_input::<[u8; 4]>("iv").unwrap();

follower_thread.assign(&follower_key, key).unwrap();
follower_thread.assign(&follower_iv, iv).unwrap();

let leader_thread_pool = leader_vm
.new_thread_pool("mock", thread_count)
.await
.unwrap();
let follower_thread_pool = follower_vm
.new_thread_pool("mock", thread_count)
.await
.unwrap();

let leader_config = StreamCipherConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();

let follower_config = StreamCipherConfigBuilder::default()
.id("test".to_string())
.build()
.unwrap();

let mut leader = MpcStreamCipher::<Aes128Ctr, _>::new(leader_config, leader_thread_pool);
leader.set_key(leader_key, leader_iv);

let mut follower = MpcStreamCipher::<Aes128Ctr, _>::new(follower_config, follower_thread_pool);
follower.set_key(follower_key, follower_iv);

let plaintext = vec![0u8; len];
let explicit_nonce = [0u8; 8];
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext);

_ = tokio::try_join!(
leader.prove_plaintext(explicit_nonce.to_vec(), plaintext),
follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext)
)
.unwrap();

_ = tokio::try_join!(leader_vm.finalize(), follower_vm.finalize()).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("stream_cipher/encrypt_private");
let rt = tokio::runtime::Runtime::new().unwrap();

let thread_count = 8;
let len = 128;
let len = 1024;

let mut group = c.benchmark_group("stream_cipher/encrypt_private");
group.throughput(Throughput::Bytes(len as u64));
group.bench_function(format!("{}", len), |b| {
b.to_async(&rt)
.iter(|| async { bench_stream_cipher_encrypt(thread_count, len).await })
});

drop(group);

let mut group = c.benchmark_group("stream_cipher/zk");
group.throughput(Throughput::Bytes(len as u64));
group.bench_function(format!("{}", len), |b| {
b.to_async(&rt)
.iter(|| async { bench_stream_cipher_zk(thread_count, len).await })
});

drop(group);
}

criterion_group!(benches, criterion_benchmark);
Expand Down
34 changes: 34 additions & 0 deletions components/cipher/stream-cipher/src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ pub trait CtrCircuit: Default + Clone + Send + Sync + 'static {

/// Returns the circuit of the cipher
fn circuit() -> Arc<Circuit>;

/// Applies the keystream to the message
fn apply_keystream(
key: &Self::KEY,
iv: &Self::IV,
start_ctr: usize,
explicit_nonce: &Self::NONCE,
msg: &[u8],
) -> Vec<u8>;
}

/// A circuit for AES-128 in counter mode.
Expand All @@ -71,4 +80,29 @@ impl CtrCircuit for Aes128Ctr {
fn circuit() -> Arc<Circuit> {
AES_CTR.clone()
}

fn apply_keystream(
key: &Self::KEY,
iv: &Self::IV,
start_ctr: usize,
explicit_nonce: &Self::NONCE,
msg: &[u8],
) -> Vec<u8> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use ctr::Ctr32BE;

let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(iv);
full_iv[4..12].copy_from_slice(explicit_nonce);
let mut cipher = Ctr32BE::<Aes128>::new(key.into(), &full_iv.into());
let mut buf = msg.to_vec();

cipher
.try_seek(start_ctr * Self::BLOCK_LEN)
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut buf);

buf
}
}
22 changes: 6 additions & 16 deletions components/cipher/stream-cipher/src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,36 @@
use mpz_circuits::{circuits::aes128_trace, once_cell::sync::Lazy, trace, Circuit, CircuitBuilder};
use std::sync::Arc;

/// AES encrypt counter block and apply message.
/// AES encrypt counter block.
///
/// # Inputs
///
/// 0. KEY: 16-byte encryption key
/// 1. IV: 4-byte IV
/// 2. EXPLICIT_NONCE: 8-byte explicit nonce
/// 3. CTR: 4-byte counter
/// 4. MSG: 16-byte message
///
/// # Outputs
///
/// 0. CIPHERTEXT: 16-byte output
/// 0. ECB: 16-byte output
pub(crate) static AES_CTR: Lazy<Arc<Circuit>> = Lazy::new(|| {
let builder = CircuitBuilder::new();
let key = builder.add_array_input::<u8, 16>();
let iv = builder.add_array_input::<u8, 4>();
let nonce = builder.add_array_input::<u8, 8>();
let ctr = builder.add_array_input::<u8, 4>();
let msg = builder.add_array_input::<u8, 16>();
let ciphertext = aes_ctr_trace(builder.state(), key, iv, nonce, ctr, msg);
builder.add_output(ciphertext);
let ecb = aes_ctr_trace(builder.state(), key, iv, nonce, ctr);
builder.add_output(ecb);

Arc::new(builder.build().unwrap())
});

#[trace]
#[dep(aes_128, aes128_trace)]
#[allow(dead_code)]
fn aes_ctr(
key: [u8; 16],
iv: [u8; 4],
explicit_nonce: [u8; 8],
ctr: [u8; 4],
msg: [u8; 16],
) -> [u8; 16] {
fn aes_ctr(key: [u8; 16], iv: [u8; 4], explicit_nonce: [u8; 8], ctr: [u8; 4]) -> [u8; 16] {
let block: Vec<_> = iv.into_iter().chain(explicit_nonce).chain(ctr).collect();
let ectr = aes_128(key, block.try_into().unwrap());

std::array::from_fn(|i| ectr[i] ^ msg[i])
aes_128(key, block.try_into().unwrap())
}

#[allow(dead_code)]
Expand Down
107 changes: 1 addition & 106 deletions components/cipher/stream-cipher/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ pub(crate) struct KeyBlockConfig<C: CtrCircuit> {
pub(crate) iv: ValueRef,
pub(crate) explicit_nonce: C::NONCE,
pub(crate) ctr: u32,
pub(crate) input_text_config: InputText,
pub(crate) output_text_config: OutputTextConfig,
_pd: PhantomData<C>,
}

Expand All @@ -45,28 +43,17 @@ impl<C: CtrCircuit> Debug for KeyBlockConfig<C> {
.field("iv", &self.iv)
.field("explicit_nonce", &self.explicit_nonce)
.field("ctr", &self.ctr)
.field("input_text_config", &self.input_text_config)
.field("output_text_config", &self.output_text_config)
.finish()
}
}

impl<C: CtrCircuit> KeyBlockConfig<C> {
pub(crate) fn new(
key: ValueRef,
iv: ValueRef,
explicit_nonce: C::NONCE,
ctr: u32,
input_text_config: InputText,
output_text_config: OutputTextConfig,
) -> Self {
pub(crate) fn new(key: ValueRef, iv: ValueRef, explicit_nonce: C::NONCE, ctr: u32) -> Self {
Self {
key,
iv,
explicit_nonce,
ctr,
input_text_config,
output_text_config,
_pd: PhantomData,
}
}
Expand Down Expand Up @@ -95,95 +82,3 @@ impl std::fmt::Debug for InputText {
}
}
}

impl InputText {
/// Returns the length of the input text.
#[allow(clippy::len_without_is_empty)]
pub(crate) fn len(&self) -> usize {
match self {
InputText::Public { text, .. } => text.len(),
InputText::Private { text, .. } => text.len(),
InputText::Blind { ids } => ids.len(),
}
}

/// Appends padding bytes to the input text.
pub(crate) fn append_padding(&mut self, append_ids: Vec<String>) {
match self {
InputText::Public { ids, text } => {
ids.extend(append_ids);
text.resize(ids.len(), 0u8);
}
InputText::Private { ids, text } => {
ids.extend(append_ids);
text.resize(ids.len(), 0u8);
}
InputText::Blind { ids } => {
ids.extend(append_ids);
}
};
}

/// Drains the first `n` bytes from the input text.
pub(crate) fn drain(&mut self, n: usize) -> InputText {
match self {
InputText::Public { ids, text } => InputText::Public {
ids: ids.drain(..n).collect(),
text: text.drain(..n).collect(),
},
InputText::Private { ids, text: bytes } => InputText::Private {
ids: ids.drain(..n).collect(),
text: bytes.drain(..n).collect(),
},
InputText::Blind { ids } => InputText::Blind {
ids: ids.drain(..n).collect(),
},
}
}
}

#[derive(Debug)]
pub(crate) enum OutputTextConfig {
Public { ids: Vec<String> },
Private { ids: Vec<String> },
Blind { ids: Vec<String> },
Shared { ids: Vec<String> },
}

impl OutputTextConfig {
/// Appends padding bytes to the output text.
pub(crate) fn append_padding(&mut self, append_ids: Vec<String>) {
match self {
OutputTextConfig::Public { ids } => {
ids.extend(append_ids);
}
OutputTextConfig::Private { ids } => {
ids.extend(append_ids);
}
OutputTextConfig::Blind { ids } => {
ids.extend(append_ids);
}
OutputTextConfig::Shared { ids } => {
ids.extend(append_ids);
}
};
}

/// Drains the first `n` bytes from the output text.
pub(crate) fn drain(&mut self, n: usize) -> OutputTextConfig {
match self {
OutputTextConfig::Public { ids } => OutputTextConfig::Public {
ids: ids.drain(..n).collect(),
},
OutputTextConfig::Private { ids } => OutputTextConfig::Private {
ids: ids.drain(..n).collect(),
},
OutputTextConfig::Blind { ids } => OutputTextConfig::Blind {
ids: ids.drain(..n).collect(),
},
OutputTextConfig::Shared { ids } => OutputTextConfig::Shared {
ids: ids.drain(..n).collect(),
},
}
}
}
Loading

0 comments on commit 786195b

Please sign in to comment.