Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve key scanning #6374

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions base_layer/common_types/src/wallet_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ impl Default for WalletType {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LedgerWallet {
account: u64,
pub pubkey: Option<RistrettoPublicKey>,
pub public_alpha: Option<RistrettoPublicKey>,
pub network: Network,
}

impl Display for LedgerWallet {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "account {}", self.account)?;
write!(f, "pubkey {}", self.pubkey.is_some())?;
write!(f, "pubkey {}", self.public_alpha.is_some())?;
Ok(())
}
}
Expand All @@ -81,10 +81,10 @@ impl Display for LedgerWallet {
const WALLET_CLA: u8 = 0x80;

impl LedgerWallet {
pub fn new(account: u64, network: Network, pubkey: Option<RistrettoPublicKey>) -> Self {
pub fn new(account: u64, network: Network, public_alpha: Option<RistrettoPublicKey>) -> Self {
Self {
account,
pubkey,
public_alpha,
network,
}
}
Expand Down
54 changes: 35 additions & 19 deletions base_layer/core/src/transactions/key_manager/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ use tari_utilities::{hex::Hex, ByteArray};
use tokio::sync::RwLock;

const LOG_TARGET: &str = "c::bn::key_manager::key_manager_service";
const KEY_MANAGER_MAX_SEARCH_DEPTH: u64 = 1_000_000;
const TRANSACTION_KEY_MANAGER_MAX_SEARCH_DEPTH: u64 = 1_000_000;

use crate::{
common::ConfidentialOutputHasher,
Expand Down Expand Up @@ -242,9 +242,11 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
KeyId::Derived { branch, label, index } => {
let public_alpha = match &self.wallet_type {
WalletType::Software(_k, pk) => pk,
WalletType::Ledger(ledger) => ledger.pubkey.as_ref().ok_or(KeyManagerServiceError::LedgerError(
"Key manager set to use ledger, ledger alpha public key missing".to_string(),
))?,
WalletType::Ledger(ledger) => {
ledger.public_alpha.as_ref().ok_or(KeyManagerServiceError::LedgerError(
"Key manager set to use ledger, ledger alpha public key missing".to_string(),
))?
},
};
let km = self
.key_managers
Expand Down Expand Up @@ -341,11 +343,20 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static

let current_index = km.key_index();

for i in 0u64..current_index + KEY_MANAGER_MAX_SEARCH_DEPTH {
let public_key = PublicKey::from_secret_key(&km.derive_key(i)?.key);
for i in 0u64..TRANSACTION_KEY_MANAGER_MAX_SEARCH_DEPTH {
let index = current_index + i;
let public_key = PublicKey::from_secret_key(&km.derive_key(index)?.key);
if public_key == *key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, i);
return Ok(i);
return Ok(index);
}
if i <= current_index && i != 0u64 {
let index = current_index - i;
let public_key = PublicKey::from_secret_key(&km.derive_key(index)?.key);
if public_key == *key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, index);
return Ok(index);
}
}
}

Expand All @@ -363,11 +374,21 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static

let current_index = km.key_index();

for i in 0u64..current_index + KEY_MANAGER_MAX_SEARCH_DEPTH {
let private_key = &km.derive_key(i)?.key;
// its most likely that the key is close to the current index, so we start searching from the current index
for i in 0u64..TRANSACTION_KEY_MANAGER_MAX_SEARCH_DEPTH {
let index = current_index + i;
let private_key = &km.derive_key(index)?.key;
if private_key == key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, i);
return Ok(i);
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, index);
return Ok(index);
}
if i <= current_index && i != 0u64 {
let index = current_index - i;
let private_key = &km.derive_key(index)?.key;
if private_key == key {
trace!(target: LOG_TARGET, "Key found in {} Key Chain at index {}", branch, index);
return Ok(index);
}
}
}

Expand Down Expand Up @@ -418,7 +439,7 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
},
KeyId::Derived { branch, label, index } => match &self.wallet_type {
WalletType::Ledger(_) => Err(KeyManagerServiceError::LedgerPrivateKeyInaccessible),
WalletType::Software(k, _pk) => {
WalletType::Software(private_alpha, _pk) => {
let km = self
.key_managers
.get(branch)
Expand All @@ -431,7 +452,7 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
let private_key = PrivateKey::from_uniform_bytes(hasher.as_ref()).map_err(|_| {
KeyManagerServiceError::UnknownError(format!("Invalid private key for {}", label))
})?;
let private_key = private_key + k;
let private_key = private_key + private_alpha;
Ok(private_key)
},
},
Expand Down Expand Up @@ -1218,12 +1239,7 @@ where TBackend: KeyManagerBackend<PublicKey> + 'static
self.crypto_factories
.range_proof
.verify_mask(output.commitment(), &private_key, value.into())?;
// Detect the branch we need to scan on for the key.
let branch = if output.is_coinbase() {
TransactionKeyManagerBranch::Coinbase.get_branch_key()
} else {
TransactionKeyManagerBranch::CommitmentMask.get_branch_key()
};
let branch = TransactionKeyManagerBranch::CommitmentMask.get_branch_key();
SWvheerden marked this conversation as resolved.
Show resolved Hide resolved
let key = match self.find_private_key_index(&branch, &private_key).await {
Ok(index) => {
self.update_current_key_index_if_higher(&branch, index).await?;
Expand Down
3 changes: 0 additions & 3 deletions base_layer/core/src/transactions/key_manager/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ pub enum TxoStage {
#[derive(Clone, Copy, EnumIter)]
pub enum TransactionKeyManagerBranch {
DataEncryption = 0x00,
Coinbase = 0x01,
MetadataEphemeralNonce = 0x02,
CommitmentMask = 0x03,
Nonce = 0x04,
Expand All @@ -71,7 +70,6 @@ impl TransactionKeyManagerBranch {
pub fn get_branch_key(self) -> String {
match self {
TransactionKeyManagerBranch::DataEncryption => "data encryption".to_string(),
TransactionKeyManagerBranch::Coinbase => "coinbase".to_string(),
TransactionKeyManagerBranch::CommitmentMask => "commitment mask".to_string(),
TransactionKeyManagerBranch::Nonce => "nonce".to_string(),
TransactionKeyManagerBranch::MetadataEphemeralNonce => "metadata ephemeral nonce".to_string(),
Expand All @@ -83,7 +81,6 @@ impl TransactionKeyManagerBranch {
pub fn from_key(key: &str) -> Self {
match key {
"data encryption" => TransactionKeyManagerBranch::DataEncryption,
"coinbase" => TransactionKeyManagerBranch::Coinbase,
"commitment mask" => TransactionKeyManagerBranch::CommitmentMask,
"metadata ephemeral nonce" => TransactionKeyManagerBranch::MetadataEphemeralNonce,
"kernel nonce" => TransactionKeyManagerBranch::KernelNonce,
Expand Down
2 changes: 1 addition & 1 deletion base_layer/core/src/validation/block_body/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ async fn it_allows_multiple_coinbases() {

let (mut block, coinbase) = blockchain.create_unmined_block(block_spec!("A1", parent: "GB")).await;
let spend_key_id = KeyId::Managed {
branch: TransactionKeyManagerBranch::Coinbase.get_branch_key(),
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
index: 42,
};
let wallet_payment_address = TariAddress::default();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ where
if start.elapsed().as_millis() > 0 {
trace!(
target: LOG_TARGET,
"sqlite profile - insert_imported_key: lock {} + db_op {} = {} ms",
"sqlite profile - get_imported_key: lock {} + db_op {} = {} ms",
acquire_lock.as_millis(),
(start.elapsed() - acquire_lock).as_millis(),
start.elapsed().as_millis()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
use std::time::Instant;

use log::*;
use tari_common_types::{transaction::TxId, types::FixedHash};
use tari_common_types::{
transaction::TxId,
types::{FixedHash, PrivateKey},
};
use tari_core::transactions::{
key_manager::{TariKeyId, TransactionKeyManagerBranch, TransactionKeyManagerInterface, TransactionKeyManagerLabel},
tari_amount::MicroMinotari,
Expand All @@ -35,6 +38,7 @@ use tari_core::transactions::{
WalletOutput,
},
};
use tari_crypto::keys::SecretKey;
use tari_key_manager::key_manager_service::KeyId;
use tari_script::{inputs, script, ExecutionStack, Opcode, TariScript};
use tari_utilities::hex::Hex;
Expand Down Expand Up @@ -156,8 +160,6 @@ where
tx_id,
hash: *hash,
});
self.update_outputs_script_private_key_and_update_key_manager_index(output)
.await?;
trace!(
target: LOG_TARGET,
"Output {} with value {} with {} recovered",
Expand Down Expand Up @@ -200,11 +202,16 @@ where
known_scripts: &[KnownOneSidedPaymentScript],
) -> Result<Option<(ExecutionStack, TariKeyId)>, OutputManagerError> {
let (input_data, script_key) = if script == &script!(Nop) {
// This is a nop, so we can just create a new key an create the input stack.
let key = KeyId::Derived {
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
label: TransactionKeyManagerLabel::ScriptKey.get_branch_key(),
index: spending_key.managed_index().unwrap(),
// This is a nop, so we can just create a new key for the input stack.
SWvheerden marked this conversation as resolved.
Show resolved Hide resolved
let key = if let Some(index) = spending_key.managed_index() {
KeyId::Derived {
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
label: TransactionKeyManagerLabel::ScriptKey.get_branch_key(),
index,
}
} else {
let private_key = PrivateKey::random(&mut rand::thread_rng());
self.master_key_manager.import_key(private_key).await?
};
let public_key = self.master_key_manager.get_public_key_at_key_id(&key).await?;
(inputs!(public_key), key)
Expand Down Expand Up @@ -259,43 +266,4 @@ where

Ok(Some((key, committed_value, payment_id)))
}

/// Find the key manager index that corresponds to the spending key in the rewound output, if found then modify
/// output to contain correct associated script private key and update the key manager to the highest index it has
/// seen so far.
async fn update_outputs_script_private_key_and_update_key_manager_index(
&mut self,
output: &mut WalletOutput,
) -> Result<(), OutputManagerError> {
let public_key = self
.master_key_manager
.get_public_key_at_key_id(&output.spending_key_id)
.await?;
let script_key = {
let found_index = self
.master_key_manager
.find_key_index(
TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
&public_key,
)
.await?;

self.master_key_manager
.update_current_key_index_if_higher(
TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
found_index,
)
.await?;

TariKeyId::Derived {
branch: TransactionKeyManagerBranch::CommitmentMask.get_branch_key(),
label: TransactionKeyManagerLabel::ScriptKey.get_branch_key(),
index: found_index,
}
};
let public_script_key = self.master_key_manager.get_public_key_at_key_id(&script_key).await?;
output.input_data = inputs!(public_script_key);
output.script_key_id = script_key;
Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ where
));

let timer = Instant::now();

loop {
let tip_header = self.get_chain_tip_header(&mut client).await?;
let tip_header_hash = tip_header.hash();
Expand Down Expand Up @@ -563,6 +562,7 @@ where
height: u64,
) -> Result<Vec<(WalletOutput, String, ImportStatus, TxId, TransactionOutput)>, UtxoScannerError> {
let mut found_outputs: Vec<(WalletOutput, String, ImportStatus, TxId, TransactionOutput)> = Vec::new();
let start = Instant::now();
found_outputs.append(
&mut self
.resources
Expand All @@ -586,6 +586,8 @@ where
})
.collect::<Result<Vec<_>, _>>()?,
);
let scanned_time = start.elapsed();
let start = Instant::now();

found_outputs.append(
&mut self
Expand Down Expand Up @@ -613,6 +615,13 @@ where
})
.collect::<Result<Vec<_>, _>>()?,
);
let one_sided_time = start.elapsed();
trace!(
target: LOG_TARGET,
"Scanned for outputs: outputs took {} ms , one-sided took {} ms",
scanned_time.as_millis(),
one_sided_time.as_millis(),
);
Ok(found_outputs)
}

Expand Down
Loading