From a6c238b97268c442a62fae67810e6eb18abf8ccf Mon Sep 17 00:00:00 2001 From: Matthew Howard Date: Tue, 2 Apr 2024 16:13:37 +0100 Subject: [PATCH] fix tests and update readme --- tests/test_blscache.py | 351 +++++++---------------------------------- wheel/README.md | 7 +- 2 files changed, 60 insertions(+), 298 deletions(-) diff --git a/tests/test_blscache.py b/tests/test_blscache.py index ef1697101..8e3443f3e 100644 --- a/tests/test_blscache.py +++ b/tests/test_blscache.py @@ -1,294 +1,57 @@ -// This cache is a bit weird because it's trying to account for validating -// mempool signatures versus block signatures. When validating block signatures, -// there's not much point in caching the pairings because we're probably not going -// to see them again unless there's a reorg. However, a spend in the mempool -// is likely to reappear in a block later, so we can save having to do the pairing -// again. So caching is primarily useful after "catch-up" (fast sync?) is done and -// we're monitoring the mempool in real-time. - -extern crate lru; -use crate::aggregate_verify as agg_ver; -use crate::gtelement::GTElement; -use crate::hash_to_g2; -use crate::PublicKey; -use crate::Signature; -use lru::LruCache; -use sha2::{Digest, Sha256}; -use std::borrow::Borrow; -use std::collections::HashMap; -use std::num::NonZeroUsize; - -#[cfg(feature = "py-bindings")] -use pyo3::types::{PyBool, PyInt, PyList}; -#[cfg(feature = "py-bindings")] -use pyo3::{pyclass, pymethods, PyResult}; - -pub type Bytes32 = [u8; 32]; -pub type Bytes48 = [u8; 48]; - -#[cfg_attr(feature = "py-bindings", pyclass(name = "BLSCache"))] -pub struct BLSCache { - cache: LruCache, -} - -impl BLSCache { - pub fn generator(cache_size: Option) -> Self { - let cache: LruCache = - LruCache::new(NonZeroUsize::new(cache_size.unwrap_or(50000)).unwrap()); - Self { cache } - } - - pub fn get_pairings, M: Borrow<[Vec]>>( - &mut self, - pks: &P, - msgs: &M, - force_cache: bool, - ) -> Vec { - let mut pairings: Vec> = vec![]; - let mut missing_count: usize = 0; - - for (pk, msg) in pks.borrow().iter().zip(msgs.borrow().iter()) { - let mut aug_msg = pk.to_vec(); - aug_msg.extend_from_slice(msg.borrow()); // pk + msg - let mut hasher = Sha256::new(); - hasher.update(aug_msg); - let h: Bytes32 = hasher.finalize().into(); - let pairing: Option<>Element> = self.cache.get(&h); - match pairing { - Some(pairing) => { - if !force_cache { - // Heuristic to avoid more expensive sig validation with pairing - // cache when it's empty and cached pairings won't be useful later - // (e.g. while syncing) - missing_count += 1; - if missing_count > pks.borrow().len() / 2 { - return vec![]; - } - } - pairings.push(Some(pairing.clone())); - } - _ => { - pairings.push(None); - } - } - } - - // G1Element.from_bytes can be expensive due to subgroup check, so we avoid recomputing it with this cache - let mut pk_bytes_to_g1: HashMap = HashMap::new(); - let mut ret: Vec = vec![]; - - for (i, pairing) in pairings.iter_mut().enumerate() { - if let Some(pairing) = pairing { - // equivalent to `if pairing is not None` - ret.push(pairing.clone()); - } else { - let mut aug_msg = pks.borrow()[i].to_vec(); - aug_msg.extend_from_slice(&msgs.borrow()[i]); // pk + msg - let aug_hash: Signature = hash_to_g2(&aug_msg); - - let pk_parsed: &mut PublicKey = pk_bytes_to_g1 - .entry(pks.borrow()[i]) - .or_insert_with(|| PublicKey::from_bytes(&pks.borrow()[i]).unwrap()); - - let pairing: GTElement = aug_hash.pair(pk_parsed); - let mut hasher = Sha256::new(); - hasher.update(&aug_msg); - let h: Bytes32 = hasher.finalize().into(); - self.cache.put(h, pairing.clone()); - ret.push(pairing); - } - } - - ret - } - - pub fn aggregate_verify( - &mut self, - pks: &Vec, - msgs: &Vec>, - sig: &Signature, - force_cache: bool, - ) -> bool { - let mut pairings: Vec = self.get_pairings(pks, msgs, force_cache); - if pairings.is_empty() { - let mut data = Vec::<(PublicKey, Vec)>::new(); - for (pk, msg) in pks.iter().zip(msgs.iter()) { - let pk = PublicKey::from_bytes_unchecked(pk).unwrap(); - data.push((pk.clone(), msg.clone())); - } - let res: bool = agg_ver(sig, data); - return res; - } - let pairings_prod = pairings.pop(); // start with the first pairing - match pairings_prod { - Some(mut prod) => { - for p in pairings.iter() { - // loop through rest of list - prod *= &p; - } - prod == sig.pair(&PublicKey::generator()) - } - _ => pairings.is_empty(), - } - } -} - -// Python Functions -#[cfg(feature = "py-bindings")] -#[pymethods] -impl BLSCache { - #[new] - pub fn init() -> Self { - Self::generator(None) - } - - #[staticmethod] - #[pyo3(name = "generator")] - pub fn py_generator(size: Option<&PyInt>) -> Self { - match size { - Some(s) => { - let usize_value: usize = s.extract::().unwrap(); - Self::generator(Some(usize_value)) - } - None => Self::generator(None), - } - } - - #[pyo3(name = "aggregate_verify")] - pub fn py_aggregate_verify( - &mut self, - pks: &PyList, - msgs: &PyList, - sig: &Signature, - force_cache: &PyBool, - ) -> PyResult { - let pks_r: Vec = pks - .iter() - .map(|item| item.extract::()) - .collect::>()?; - let msgs_r: Vec> = msgs - .iter() - .map(|item| item.extract::>()) - .collect::>()?; - let force_cache_bool = force_cache.extract::()?; - Ok(self.aggregate_verify(&pks_r, &msgs_r, sig, force_cache_bool)) - } - - #[pyo3(name = "len")] - pub fn py_len(&self) -> PyResult { - Ok(self.cache.len()) - } -} - -#[cfg(test)] -pub mod tests { - use super::*; - use crate::aggregate; - use crate::sign; - use crate::SecretKey; - - #[test] - pub fn test_instantiation() { - let mut bls_cache: BLSCache = BLSCache::generator(None); - let byte_array: [u8; 32] = [0; 32]; - let sk: SecretKey = SecretKey::from_seed(&byte_array); - let pk: PublicKey = sk.public_key(); - let msg: [u8; 32] = [106; 32]; - let mut aug_msg: Vec = pk.clone().to_bytes().to_vec(); - aug_msg.extend_from_slice(&msg); // pk + msg - let aug_hash = hash_to_g2(&aug_msg); - let pairing = aug_hash.pair(&pk); - let mut hasher = Sha256::new(); - hasher.update(&aug_msg); - let h: Bytes32 = hasher.finalize().into(); - bls_cache.cache.put(h, pairing.clone()); - assert_eq!(*bls_cache.cache.get(&h).unwrap(), pairing); - } - - #[test] - pub fn test_aggregate_verify() { - let mut bls_cache: BLSCache = BLSCache::generator(None); - assert_eq!(bls_cache.cache.len(), 0); - let byte_array: [u8; 32] = [0; 32]; - let sk: SecretKey = SecretKey::from_seed(&byte_array); - let pk: PublicKey = sk.public_key(); - let msg: Vec = [106; 32].to_vec(); - let sig: Signature = sign(&sk, &msg); - let pk_list: Vec<[u8; 48]> = [pk.to_bytes()].to_vec(); - let msg_list: Vec> = [msg].to_vec(); - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &sig, true)); - assert_eq!(bls_cache.cache.len(), 1); - // try again with (pk, msg) cached - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &sig, true)); - assert_eq!(bls_cache.cache.len(), 1); - } - - #[test] - pub fn test_cache() { - let mut bls_cache: BLSCache = BLSCache::generator(None); - assert_eq!(bls_cache.cache.len(), 0); - let byte_array: [u8; 32] = [0; 32]; - let sk: SecretKey = SecretKey::from_seed(&byte_array); - let pk: PublicKey = sk.public_key(); - let msg: Vec = [106; 32].to_vec(); - let sig: Signature = sign(&sk, &msg); - let mut pk_list: Vec<[u8; 48]> = [pk.to_bytes()].to_vec(); - let mut msg_list: Vec> = [msg].to_vec(); - // add first to cache - // try one cached, one not cached - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &sig, false)); - assert_eq!(bls_cache.cache.len(), 1); - let byte_array: [u8; 32] = [1; 32]; - let sk: SecretKey = SecretKey::from_seed(&byte_array); - let pk: PublicKey = sk.public_key(); - let msg: Vec = [107; 32].to_vec(); - let sig = aggregate([sig, sign(&sk, &msg)]); - pk_list.push(pk.to_bytes()); - msg_list.push(msg); - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &sig, false)); - assert_eq!(bls_cache.cache.len(), 2); - // try reusing a pubkey - let pk: PublicKey = sk.public_key(); - let msg: Vec = [108; 32].to_vec(); - let sig = aggregate([sig, sign(&sk, &msg)]); - pk_list.push(pk.to_bytes()); - msg_list.push(msg); - // try with force_cache disabled - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &sig, false)); - assert_eq!(bls_cache.cache.len(), 2); - // now force it to save the pairing - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &sig, true)); - assert_eq!(bls_cache.cache.len(), 3); - } - - #[test] - pub fn test_cache_limit() { - // set cache size to 3 - let mut bls_cache: BLSCache = BLSCache::generator(Some(3)); - assert_eq!(bls_cache.cache.len(), 0); - // create 5 pk/msg combos - for i in 1..=5 { - let byte_array: [u8; 32] = [i as u8; 32]; - let sk: SecretKey = SecretKey::from_seed(&byte_array); - let pk: PublicKey = sk.public_key(); - let msg: Vec = [106; 32].to_vec(); - let sig: Signature = sign(&sk, &msg); - let pk_list: Vec<[u8; 48]> = [pk.to_bytes()].to_vec(); - let msg_list: Vec> = [msg].to_vec(); - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &sig, true)); - } - assert_eq!(bls_cache.cache.len(), 3); - // recreate first key - let byte_array: [u8; 32] = [1; 32]; - let sk: SecretKey = SecretKey::from_seed(&byte_array); - let pk: PublicKey = sk.public_key(); - let msg: Vec = [106; 32].to_vec(); - let mut aug_msg = pk.to_bytes().to_vec(); - aug_msg.extend_from_slice(&msg); // pk + msg - let mut hasher = Sha256::new(); - hasher.update(aug_msg); - let h: Bytes32 = hasher.finalize().into(); - // assert first key has been removed - assert!(bls_cache.cache.get(&h).is_none()); - } -} \ No newline at end of file +from chia_rs import G1Element, PrivateKey, AugSchemeMPL, G2Element, BLSCache +from chia.util.ints import uint64 +import pytest +from typing import List + + +def test_instantiation() -> None: + bls_cache = BLSCache.generator() + assert bls_cache.len() == 0 + assert BLSCache is not None + seed: bytes = bytes.fromhex( + "003206f418c701193458c013120c5906dc12663ad1520c3e596eb6092c14fe16" + ) + + sk: PrivateKey = AugSchemeMPL.key_gen(seed) + pk: G1Element = sk.get_g1() + msg = b"hello" + sig: G2Element = AugSchemeMPL.sign(sk, msg) + pks: List[bytes] = [pk.to_bytes()] + msgs: List[bytes] = [msg] + result = bls_cache.aggregate_verify(pks, msgs, sig, True) + assert result + assert bls_cache.len() == 1 + result = bls_cache.aggregate_verify(pks, msgs, sig, True) + assert result + assert bls_cache.len() == 1 + pks.append(pk.to_bytes()) + + msg = b"world" + msgs.append(msg) + sig: G2Element = AugSchemeMPL.aggregate([sig, AugSchemeMPL.sign(sk, msg)]) + result = bls_cache.aggregate_verify(pks, msgs, sig, True) + assert result + assert bls_cache.len() == 2 + + +def test_cache_limit() -> None: + bls_cache = BLSCache.generator(3) + assert bls_cache.len() == 0 + assert BLSCache is not None + seed: bytes = bytes.fromhex( + "003206f418c701193458c013120c5906dc12663ad1520c3e596eb6092c14fe16" + ) + + sk: PrivateKey = AugSchemeMPL.key_gen(seed) + pk: G1Element = sk.get_g1() + pks: List[bytes] = [] + msgs: List[bytes] = [] + pk_bytes = pk.to_bytes() + sigs: List[G2Element] = [] + for i in [1, 2, 3, 4]: + msgs.append(i.to_bytes()) + pks.append(pk_bytes) + sigs.append(AugSchemeMPL.sign(sk, i.to_bytes())) + result = bls_cache.aggregate_verify(pks, msgs, AugSchemeMPL.aggregate(sigs), True) + assert result + assert bls_cache.len() == 3 \ No newline at end of file diff --git a/wheel/README.md b/wheel/README.md index fb79d0562..b8e9f0b96 100644 --- a/wheel/README.md +++ b/wheel/README.md @@ -3,11 +3,10 @@ The `chia_rs` wheel contains python bindings for code from the `chia` crate. To run the tests: ``` cd wheel -python3 -m venv venv +pytho -m venv venv . ./venv/bin/activate -pip install -r requirements.txt +python -m pip install -r requirements.txt maturin develop -cd .. -pytest tests +python -m pytest ../tests ```