diff --git a/infrastructure/tari_script/src/op_codes.rs b/infrastructure/tari_script/src/op_codes.rs index 50350e0dbb..2eab0a48bd 100644 --- a/infrastructure/tari_script/src/op_codes.rs +++ b/infrastructure/tari_script/src/op_codes.rs @@ -118,6 +118,7 @@ pub const OP_HASH_BLAKE256: u8 = 0xb0; pub const OP_HASH_SHA256: u8 = 0xb1; pub const OP_HASH_SHA3: u8 = 0xb2; pub const OP_TO_RISTRETTO_POINT: u8 = 0xb3; +pub const OP_CHECK_MULTI_SIG_VERIFY_AGGREGATE_PUB_KEY: u8 = 0xb4; // Opcode constants: Miscellaneous pub const OP_RETURN: u8 = 0x60; @@ -234,6 +235,9 @@ pub enum Opcode { /// Identical to CheckMultiSig, except that nothing is pushed to the stack if the m signatures are valid, and the /// operation fails with VERIFY_FAILED if any of the signatures are invalid. CheckMultiSigVerify(u8, u8, Vec, Box), + /// Pop m signatures from the stack. If m signatures out of the provided n public keys sign the 32-byte message, + /// push the aggregate of the public keys to the stack, otherwise fails with VERIFY_FAILED. + CheckMultiSigVerifyAggregatePubKey(u8, u8, Vec, Box), /// Pops the top element which must be a valid 32-byte scalar or hash and calculates the corresponding Ristretto /// point, and pushes the result to the stack. Fails with EMPTY_STACK if the stack is empty. ToRistrettoPoint, @@ -355,6 +359,10 @@ impl Opcode { let (m, n, keys, msg, end) = Opcode::read_multisig_args(bytes)?; Ok((CheckMultiSigVerify(m, n, keys, msg), &bytes[end..])) }, + OP_CHECK_MULTI_SIG_VERIFY_AGGREGATE_PUB_KEY => { + let (m, n, keys, msg, end) = Opcode::read_multisig_args(bytes)?; + Ok((CheckMultiSigVerifyAggregatePubKey(m, n, keys, msg), &bytes[end..])) + }, OP_TO_RISTRETTO_POINT => Ok((ToRistrettoPoint, &bytes[1..])), OP_RETURN => Ok((Return, &bytes[1..])), OP_IF_THEN => Ok((IfThen, &bytes[1..])), @@ -464,6 +472,13 @@ impl Opcode { } array.extend_from_slice(msg.deref()); }, + CheckMultiSigVerifyAggregatePubKey(m, n, public_keys, msg) => { + array.extend_from_slice(&[OP_CHECK_MULTI_SIG_VERIFY_AGGREGATE_PUB_KEY, *m, *n]); + for public_key in public_keys { + array.extend(public_key.as_bytes()); + } + array.extend_from_slice(msg.deref()); + }, ToRistrettoPoint => array.push(OP_TO_RISTRETTO_POINT), Return => array.push(OP_RETURN), IfThen => array.push(OP_IF_THEN), @@ -530,6 +545,17 @@ impl fmt::Display for Opcode { (*msg).to_hex() ) }, + CheckMultiSigVerifyAggregatePubKey(m, n, public_keys, msg) => { + let keys: Vec = public_keys.iter().map(|p| p.to_hex()).collect(); + write!( + fmt, + "CheckMultiSigVerifyAggregatePubKey({}, {}, [{}], {})", + *m, + *n, + keys.join(", "), + (*msg).to_hex() + ) + }, ToRistrettoPoint => write!(fmt, "ToRistrettoPoint"), Return => write!(fmt, "Return"), IfThen => write!(fmt, "IfThen"), @@ -766,12 +792,20 @@ mod test { 6c9cb4d3e57351462122310fa22c90b1e6dfb528d64615363d1261a75da3e401)", ); test_checkmultisig( - &Opcode::CheckMultiSigVerify(1, 2, keys, Box::new(*msg)), + &Opcode::CheckMultiSigVerify(1, 2, keys.clone(), Box::new(*msg)), OP_CHECK_MULTI_SIG_VERIFY, "CheckMultiSigVerify(1, 2, [9c8bc5f90d221191748e8dd7686f09e1114b4bada4c367ed58ae199c51eb100b, \ 56e9f018b138ba843521b3243a29d81730c3a4c25108b108b1ca47c2132db569], \ 6c9cb4d3e57351462122310fa22c90b1e6dfb528d64615363d1261a75da3e401)", ); + test_checkmultisig( + &Opcode::CheckMultiSigVerifyAggregatePubKey(1, 2, keys, Box::new(*msg)), + OP_CHECK_MULTI_SIG_VERIFY_AGGREGATE_PUB_KEY, + "CheckMultiSigVerifyAggregatePubKey(1, 2, \ + [9c8bc5f90d221191748e8dd7686f09e1114b4bada4c367ed58ae199c51eb100b, \ + 56e9f018b138ba843521b3243a29d81730c3a4c25108b108b1ca47c2132db569], \ + 6c9cb4d3e57351462122310fa22c90b1e6dfb528d64615363d1261a75da3e401)", + ); } #[test] diff --git a/infrastructure/tari_script/src/script.rs b/infrastructure/tari_script/src/script.rs index 38ab6cb8b1..df7d91b945 100644 --- a/infrastructure/tari_script/src/script.rs +++ b/infrastructure/tari_script/src/script.rs @@ -248,19 +248,26 @@ impl TariScript { } }, CheckMultiSig(m, n, public_keys, msg) => { - if self.check_multisig(stack, *m, *n, public_keys, *msg.deref())? { + if self.check_multisig(stack, *m, *n, public_keys, *msg.deref())?.is_some() { stack.push(Number(1)) } else { stack.push(Number(0)) } }, CheckMultiSigVerify(m, n, public_keys, msg) => { - if self.check_multisig(stack, *m, *n, public_keys, *msg.deref())? { + if self.check_multisig(stack, *m, *n, public_keys, *msg.deref())?.is_some() { Ok(()) } else { Err(ScriptError::VerifyFailed) } }, + CheckMultiSigVerifyAggregatePubKey(m, n, public_keys, msg) => { + if let Some(agg_pub_key) = self.check_multisig(stack, *m, *n, public_keys, *msg.deref())? { + stack.push(PublicKey(agg_pub_key)) + } else { + Err(ScriptError::VerifyFailed) + } + }, ToRistrettoPoint => self.handle_to_ristretto_point(stack), Return => Err(ScriptError::Return), IfThen => TariScript::handle_if_then(stack, state), @@ -505,9 +512,9 @@ impl TariScript { n: u8, public_keys: &[RistrettoPublicKey], message: Message, - ) -> Result { - if m == 0 || n == 0 || m > n || n > MAX_MULTISIG_LIMIT { - return Err(ScriptError::InvalidData); + ) -> Result, ScriptError> { + if m == 0 || n == 0 || m > n || n > MAX_MULTISIG_LIMIT || public_keys.len() != n as usize { + return Err(ScriptError::ValueExceedsBounds); } // pop m sigs let m = m as usize; @@ -524,20 +531,25 @@ impl TariScript { #[allow(clippy::mutable_key_type)] let mut sig_set = HashSet::new(); + let mut agg_pub_key = RistrettoPublicKey::default(); for s in &signatures { for (i, pk) in public_keys.iter().enumerate() { if !sig_set.contains(s) && !key_signed[i] && s.verify_challenge(pk, &message) { key_signed[i] = true; sig_set.insert(s); + agg_pub_key = agg_pub_key + pk; break; } } if !sig_set.contains(s) { - return Ok(false); + return Ok(None); } } - - Ok(sig_set.len() == m) + if sig_set.len() == m { + Ok(Some(agg_pub_key)) + } else { + Ok(None) + } } fn handle_to_ristretto_point(&self, stack: &mut ExecutionStack) -> Result<(), ScriptError> { @@ -625,6 +637,7 @@ mod test { inputs, op_codes::{slice_to_boxed_hash, slice_to_boxed_message, HashValue, Message}, ExecutionStack, + Opcode::CheckMultiSigVerifyAggregatePubKey, ScriptContext, StackItem, StackItem::{Commitment, Hash, Number}, @@ -1145,21 +1158,21 @@ mod test { let script = TariScript::new(ops); let inputs = inputs!(s_alice.clone()); let err = script.execute(&inputs).unwrap_err(); - assert_eq!(err, ScriptError::InvalidData); + assert_eq!(err, ScriptError::ValueExceedsBounds); let keys = vec![p_alice.clone(), p_bob.clone()]; let ops = vec![CheckMultiSig(1, 0, keys, msg.clone())]; let script = TariScript::new(ops); let inputs = inputs!(s_alice.clone()); let err = script.execute(&inputs).unwrap_err(); - assert_eq!(err, ScriptError::InvalidData); + assert_eq!(err, ScriptError::ValueExceedsBounds); let keys = vec![p_alice, p_bob]; let ops = vec![CheckMultiSig(2, 1, keys, msg)]; let script = TariScript::new(ops); let inputs = inputs!(s_alice); let err = script.execute(&inputs).unwrap_err(); - assert_eq!(err, ScriptError::InvalidData); + assert_eq!(err, ScriptError::ValueExceedsBounds); // max n is 32 let (msg, data) = multisig_data(33); @@ -1169,7 +1182,7 @@ mod test { let items = sigs.map(StackItem::Signature).collect(); let inputs = ExecutionStack::new(items); let err = script.execute(&inputs).unwrap_err(); - assert_eq!(err, ScriptError::InvalidData); + assert_eq!(err, ScriptError::ValueExceedsBounds); // 3 of 4 let (msg, data) = multisig_data(4); @@ -1258,7 +1271,7 @@ mod test { // 1 of 3 let keys = vec![p_alice.clone(), p_bob.clone(), p_carol.clone()]; - let ops = vec![CheckMultiSigVerify(1, 2, keys, msg.clone())]; + let ops = vec![CheckMultiSigVerify(1, 3, keys, msg.clone())]; let script = TariScript::new(ops); let inputs = inputs!(Number(1), s_alice.clone()); @@ -1292,6 +1305,31 @@ mod test { let err = script.execute(&inputs).unwrap_err(); assert_eq!(err, ScriptError::VerifyFailed); + // 2 of 3 (returning the aggregate public key of the signatories) + let keys = vec![p_alice.clone(), p_bob.clone(), p_carol.clone()]; + let ops = vec![CheckMultiSigVerifyAggregatePubKey(2, 3, keys, msg.clone())]; + let script = TariScript::new(ops); + + let inputs = inputs!(s_alice.clone(), s_bob.clone()); + let agg_pub_key = script.execute(&inputs).unwrap(); + assert_eq!(agg_pub_key, StackItem::PublicKey(p_alice.clone() + p_bob.clone())); + + let inputs = inputs!(s_alice.clone(), s_carol.clone()); + let agg_pub_key = script.execute(&inputs).unwrap(); + assert_eq!(agg_pub_key, StackItem::PublicKey(p_alice.clone() + p_carol.clone())); + + let inputs = inputs!(s_bob.clone(), s_carol.clone()); + let agg_pub_key = script.execute(&inputs).unwrap(); + assert_eq!(agg_pub_key, StackItem::PublicKey(p_bob.clone() + p_carol.clone())); + + let inputs = inputs!(s_alice.clone(), s_carol.clone(), s_bob.clone()); + let err = script.execute(&inputs).unwrap_err(); + assert_eq!(err, ScriptError::NonUnitLengthStack); + + let inputs = inputs!(p_bob.clone()); + let err = script.execute(&inputs).unwrap_err(); + assert_eq!(err, ScriptError::StackUnderflow); + // 3 of 3 let keys = vec![p_alice.clone(), p_bob.clone(), p_carol]; let ops = vec![CheckMultiSigVerify(3, 3, keys, msg.clone())]; @@ -1313,21 +1351,21 @@ mod test { let script = TariScript::new(ops); let inputs = inputs!(s_alice.clone()); let err = script.execute(&inputs).unwrap_err(); - assert_eq!(err, ScriptError::InvalidData); + assert_eq!(err, ScriptError::ValueExceedsBounds); let keys = vec![p_alice.clone(), p_bob.clone()]; let ops = vec![CheckMultiSigVerify(1, 0, keys, msg.clone())]; let script = TariScript::new(ops); let inputs = inputs!(s_alice.clone()); let err = script.execute(&inputs).unwrap_err(); - assert_eq!(err, ScriptError::InvalidData); + assert_eq!(err, ScriptError::ValueExceedsBounds); let keys = vec![p_alice, p_bob]; let ops = vec![CheckMultiSigVerify(2, 1, keys, msg)]; let script = TariScript::new(ops); let inputs = inputs!(s_alice); let err = script.execute(&inputs).unwrap_err(); - assert_eq!(err, ScriptError::InvalidData); + assert_eq!(err, ScriptError::ValueExceedsBounds); // 3 of 4 let (msg, data) = multisig_data(4);