From e838e2a37715cc736fff160a260c703e2a8213c3 Mon Sep 17 00:00:00 2001 From: yse Date: Sun, 24 Nov 2024 01:16:26 +0100 Subject: [PATCH] feat: add `last_derivation_index` to sync service --- lib/core/src/persist/cache.rs | 27 +++++++-- lib/core/src/persist/mod.rs | 2 +- lib/core/src/persist/sync.rs | 42 ++++++++++++- lib/core/src/sync/mod.rs | 103 +++++++++++++++++++++++++++++++- lib/core/src/sync/model/data.rs | 10 ++++ lib/core/src/sync/model/mod.rs | 4 ++ 6 files changed, 178 insertions(+), 10 deletions(-) diff --git a/lib/core/src/persist/cache.rs b/lib/core/src/persist/cache.rs index 88e9ca32..064ca2a7 100644 --- a/lib/core/src/persist/cache.rs +++ b/lib/core/src/persist/cache.rs @@ -2,13 +2,14 @@ use anyhow::Result; use rusqlite::{Transaction, TransactionBehavior}; use std::str::FromStr; +use crate::sync::model::{data::LAST_DERIVATION_INDEX_DATA_ID, RecordType}; + use super::Persister; const KEY_SWAPPER_PROXY_URL: &str = "swapper_proxy_url"; const KEY_IS_FIRST_SYNC_COMPLETE: &str = "is_first_sync_complete"; const KEY_WEBHOOK_URL: &str = "webhook_url"; -// TODO: The `last_derivation_index` needs to be synced -const KEY_LAST_DERIVATION_INDEX: &str = "last_derivation_index"; +pub(crate) const KEY_LAST_DERIVATION_INDEX: &str = "last_derivation_index"; impl Persister { fn get_cached_item_inner(tx: &Transaction, key: &str) -> Result> { @@ -20,7 +21,11 @@ impl Persister { Ok(res.ok()) } - fn update_cached_item_inner(tx: &Transaction, key: &str, value: String) -> Result<()> { + pub(crate) fn update_cached_item_inner( + tx: &Transaction, + key: &str, + value: String, + ) -> Result<()> { tx.execute( "INSERT OR REPLACE INTO cached_items (key, value) VALUES (?1,?2)", (key, value), @@ -92,7 +97,21 @@ impl Persister { } pub fn set_last_derivation_index(&self, index: u32) -> Result<()> { - self.update_cached_item(KEY_LAST_DERIVATION_INDEX, index.to_string()) + let mut con = self.get_connection()?; + let tx = con.transaction_with_behavior(TransactionBehavior::Immediate)?; + + Self::update_cached_item_inner(&tx, KEY_LAST_DERIVATION_INDEX, index.to_string())?; + self.commit_outgoing( + &tx, + LAST_DERIVATION_INDEX_DATA_ID, + RecordType::LastDerivationIndex, + // insert a mock updated field so that merging with incoming data works as expected + Some(vec!["last-derivation-index".to_string()]), + )?; + tx.commit()?; + self.sync_trigger.try_send(())?; + + Ok(()) } pub fn get_last_derivation_index(&self) -> Result> { diff --git a/lib/core/src/persist/mod.rs b/lib/core/src/persist/mod.rs index 7e99c792..01113b9c 100644 --- a/lib/core/src/persist/mod.rs +++ b/lib/core/src/persist/mod.rs @@ -1,6 +1,6 @@ mod address; mod backup; -mod cache; +pub(crate) mod cache; pub(crate) mod chain; mod migrations; pub(crate) mod receive; diff --git a/lib/core/src/persist/sync.rs b/lib/core/src/persist/sync.rs index cd359f56..7b34e70b 100644 --- a/lib/core/src/persist/sync.rs +++ b/lib/core/src/persist/sync.rs @@ -5,10 +5,10 @@ use rusqlite::{ named_params, Connection, OptionalExtension, Row, Statement, Transaction, TransactionBehavior, }; -use super::{PaymentState, Persister}; +use super::{cache::KEY_LAST_DERIVATION_INDEX, PaymentState, Persister}; use crate::{ sync::model::{ - data::{ChainSyncData, ReceiveSyncData, SendSyncData}, + data::{ChainSyncData, ReceiveSyncData, SendSyncData, LAST_DERIVATION_INDEX_DATA_ID}, sync::Record, RecordType, SyncOutgoingChanges, SyncSettings, SyncState, }, @@ -691,4 +691,42 @@ impl Persister { Ok(()) } + + pub(crate) fn commit_incoming_address_index( + &self, + new_address_index: u32, + sync_state: SyncState, + last_commit_time: Option, + ) -> Result<()> { + let mut con = self.get_connection()?; + let tx = con.transaction_with_behavior(TransactionBehavior::Immediate)?; + + if let Some(last_commit_time) = last_commit_time { + Self::check_commit_update( + &tx, + &Record::get_id_from_record_type( + RecordType::LastDerivationIndex, + LAST_DERIVATION_INDEX_DATA_ID, + ), + last_commit_time, + )?; + } + + Self::update_cached_item_inner( + &tx, + KEY_LAST_DERIVATION_INDEX, + new_address_index.to_string(), + )?; + + Self::set_sync_state_stmt(&tx)?.execute(named_params! { + ":data_id": &sync_state.data_id, + ":record_id": &sync_state.record_id, + ":record_revision": &sync_state.record_revision, + ":is_local": &sync_state.is_local, + })?; + + tx.commit()?; + + Ok(()) + } } diff --git a/lib/core/src/sync/mod.rs b/lib/core/src/sync/mod.rs index f1e74acd..7ac51406 100644 --- a/lib/core/src/sync/mod.rs +++ b/lib/core/src/sync/mod.rs @@ -9,7 +9,10 @@ use tokio::sync::{watch, Mutex}; use crate::sync::model::sync::{Record, SetRecordRequest, SetRecordStatus}; use crate::utils; -use crate::{persist::Persister, prelude::Signer}; +use crate::{ + persist::{cache::KEY_LAST_DERIVATION_INDEX, Persister}, + prelude::Signer, +}; use self::client::SyncerClient; use self::model::{ @@ -124,6 +127,9 @@ impl SyncService { is_update, last_commit_time, )?, + SyncData::LastDerivationIndex(new_address_index) => self + .persister + .commit_incoming_address_index(new_address_index, sync_state, last_commit_time)?, } Ok(()) } @@ -154,6 +160,12 @@ impl SyncService { .into(); SyncData::Chain(chain_data) } + RecordType::LastDerivationIndex => SyncData::LastDerivationIndex( + self.persister + .get_cached_item(KEY_LAST_DERIVATION_INDEX)? + .ok_or(anyhow!("Could not find last derivation index"))? + .parse()?, + ), }; Ok(data) } @@ -327,9 +339,9 @@ mod tests { use std::{collections::HashMap, sync::Arc}; use crate::{ - persist::Persister, + persist::{cache::KEY_LAST_DERIVATION_INDEX, Persister}, prelude::{Direction, PaymentState, Signer}, - sync::model::SyncState, + sync::model::{data::LAST_DERIVATION_INDEX_DATA_ID, SyncState}, test_utils::{ chain_swap::new_chain_swap, persist::{create_persister, new_receive_swap, new_send_swap}, @@ -625,4 +637,89 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_last_derivation_index_update() -> Result<()> { + create_persister!(persister); + let signer: Arc> = Arc::new(Box::new(MockSigner::new())); + + let (incoming_tx, outgoing_records, sync_service) = + new_sync_service(persister.clone(), signer.clone())?; + + // Check pull + assert_eq!(persister.get_cached_item(KEY_LAST_DERIVATION_INDEX)?, None); + + let new_last_derivation_index = 10; + let data = SyncData::LastDerivationIndex(new_last_derivation_index); + incoming_tx + .send(Record::new(data, 0, signer.clone())?) + .await?; + + sync_service.pull().await?; + + assert_eq!( + persister.get_cached_item(KEY_LAST_DERIVATION_INDEX)?, + Some(new_last_derivation_index.to_string()) + ); + + // Check push + let new_last_derivation_index = 20; + persister.set_last_derivation_index(new_last_derivation_index)?; + + sync_service.push().await?; + + let outgoing = outgoing_records.lock().await; + let record = get_outgoing_record( + persister.clone(), + &outgoing, + LAST_DERIVATION_INDEX_DATA_ID, + RecordType::LastDerivationIndex, + )?; + let decrypted_record = record.clone().decrypt(signer.clone())?; + match decrypted_record.data { + SyncData::LastDerivationIndex(last_derivation_index) => { + assert_eq!(last_derivation_index, new_last_derivation_index); + } + _ => { + return Err(anyhow::anyhow!("Unexpected sync data type received.")); + } + } + + // Check pull with merge + let new_local_last_derivation_index = 30; + persister.set_last_derivation_index(new_local_last_derivation_index)?; + + let new_remote_last_derivation_index = 25; + let data = SyncData::LastDerivationIndex(new_remote_last_derivation_index); + incoming_tx + .send(Record::new(data, 0, signer.clone())?) + .await?; + + sync_service.pull().await?; + + // Newer one is persisted (local > remote) + assert_eq!( + persister.get_cached_item(KEY_LAST_DERIVATION_INDEX)?, + Some(new_local_last_derivation_index.to_string()) + ); + + let new_local_last_derivation_index = 35; + persister.set_last_derivation_index(new_local_last_derivation_index)?; + + let new_remote_last_derivation_index = 40; + let data = SyncData::LastDerivationIndex(new_remote_last_derivation_index); + incoming_tx + .send(Record::new(data, 2, signer.clone())?) + .await?; + + sync_service.pull().await?; + + // Newer one is persisted (remote > local) + assert_eq!( + persister.get_cached_item(KEY_LAST_DERIVATION_INDEX)?, + Some(new_remote_last_derivation_index.to_string()) + ); + + Ok(()) + } } diff --git a/lib/core/src/sync/model/data.rs b/lib/core/src/sync/model/data.rs index 24c34ffc..35f29404 100644 --- a/lib/core/src/sync/model/data.rs +++ b/lib/core/src/sync/model/data.rs @@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize}; use crate::prelude::{ChainSwap, Direction, ReceiveSwap, SendSwap}; +pub(crate) const LAST_DERIVATION_INDEX_DATA_ID: &str = "last-derivation-index"; + #[derive(Serialize, Deserialize, Clone, Debug)] pub(crate) struct ChainSyncData { pub(crate) swap_id: String, @@ -155,6 +157,7 @@ pub(crate) enum SyncData { Chain(ChainSyncData), Send(SendSyncData), Receive(ReceiveSyncData), + LastDerivationIndex(u32), } impl SyncData { @@ -163,6 +166,7 @@ impl SyncData { SyncData::Chain(chain_data) => &chain_data.swap_id, SyncData::Send(send_data) => &send_data.swap_id, SyncData::Receive(receive_data) => &receive_data.swap_id, + SyncData::LastDerivationIndex(_) => LAST_DERIVATION_INDEX_DATA_ID, } } @@ -181,6 +185,12 @@ impl SyncData { (SyncData::Receive(ref mut base), SyncData::Receive(other)) => { base.merge(other, updated_fields) } + ( + SyncData::LastDerivationIndex(our_index), + SyncData::LastDerivationIndex(their_index), + ) => { + *our_index = std::cmp::max(*their_index, *our_index); + } _ => return Err(anyhow::anyhow!("Cannot merge data from two separate types")), }; Ok(()) diff --git a/lib/core/src/sync/model/mod.rs b/lib/core/src/sync/model/mod.rs index 6b63486a..970849d7 100644 --- a/lib/core/src/sync/model/mod.rs +++ b/lib/core/src/sync/model/mod.rs @@ -26,6 +26,7 @@ pub(crate) enum RecordType { Receive = 0, Send = 1, Chain = 2, + LastDerivationIndex = 3, } impl ToSql for RecordType { @@ -41,6 +42,7 @@ impl FromSql for RecordType { 0 => Ok(Self::Receive), 1 => Ok(Self::Send), 2 => Ok(Self::Chain), + 3 => Ok(Self::LastDerivationIndex), _ => Err(FromSqlError::OutOfRange(i)), }, _ => Err(FromSqlError::InvalidType), @@ -105,6 +107,7 @@ impl Record { SyncData::Chain(_) => "chain-swap", SyncData::Send(_) => "send-swap", SyncData::Receive(_) => "receive-swap", + SyncData::LastDerivationIndex(_) => "wallet-address", } .to_string(); Self::id(prefix, data.id()) @@ -115,6 +118,7 @@ impl Record { RecordType::Chain => "chain-swap", RecordType::Send => "send-swap", RecordType::Receive => "receive-swap", + RecordType::LastDerivationIndex => "wallet-address", } .to_string(); Self::id(prefix, data_id)