diff --git a/proto/anki/ankidroid.proto b/proto/anki/ankidroid.proto
new file mode 100644
index 00000000000..19e819475c4
--- /dev/null
+++ b/proto/anki/ankidroid.proto
@@ -0,0 +1,71 @@
+syntax = "proto3";
+
+option java_multiple_files = true;
+
+import "anki/generic.proto";
+import "anki/scheduler.proto";
+
+package anki.ankidroid;
+
+service AnkidroidService {
+  rpc SchedTimingTodayLegacy (SchedTimingTodayLegacyRequest) returns (scheduler.SchedTimingTodayResponse);
+  rpc LocalMinutesWestLegacy (generic.Int64) returns (generic.Int32);
+  rpc RunDbCommand(generic.Json) returns (generic.Json);
+  rpc RunDbCommandProto(generic.Json) returns (DBResponse);
+  rpc InsertForId(generic.Json) returns (generic.Int64);
+  rpc RunDbCommandForRowCount(generic.Json) returns (generic.Int64);
+  rpc FlushAllQueries(generic.Empty) returns (generic.Empty);
+  rpc FlushQuery(generic.Int32) returns (generic.Empty);
+  rpc GetNextResultPage(GetNextResultPageRequest) returns (DBResponse);
+  rpc SetPageSize(generic.Int64) returns (generic.Empty);
+  rpc GetColumnNamesFromQuery(generic.String) returns (generic.StringList);
+  rpc GetActiveSequenceNumbers(generic.Empty) returns (GetActiveSequenceNumbersResponse);
+  rpc DebugProduceError(generic.String) returns (generic.Empty);
+}
+
+message DebugActiveDatabaseSequenceNumbersResponse {
+  repeated int32 sequence_numbers = 1;
+}
+
+message SchedTimingTodayLegacyRequest {
+  int64 created_secs = 1;
+  optional sint32 created_mins_west = 2;
+  int64 now_secs = 3;
+  sint32 now_mins_west = 4;
+  sint32 rollover_hour = 5;
+}
+
+// We expect in Java: Null, String, Short, Int, Long, Float, Double, Boolean, Blob (unused)
+// We get: DbResult (Null, String, i64, f64, Vec<u8>), which matches the SQLite documentation
+message SqlValue {
+  oneof Data {
+    string stringValue = 1;
+    int64 longValue = 2;
+    double doubleValue = 3;
+    bytes blobValue = 4;
+  }
+}
+
+message Row {
+  repeated SqlValue fields = 1;
+}
+
+message DbResult {
+  repeated Row rows = 1;
+}
+
+message DBResponse {
+  DbResult result = 1;
+  int32 sequenceNumber = 2;
+  int32 rowCount = 3;
+  int64 startIndex = 4;
+}
+
+message GetNextResultPageRequest {
+  int32 sequence = 1;
+  int64 index = 2;
+}
+
+message GetActiveSequenceNumbersResponse {
+  repeated int32 numbers = 1;
+}
\ No newline at end of file
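Worth spelling out, since the proto alone doesn't show the flow: RunDbCommandProto returns the first page plus the total row count and a sequence number, and the caller then walks GetNextResultPage until it has rowCount rows; the backend frees the cached remainder when the final empty page is requested, or on FlushQuery/FlushAllQueries. A hedged sketch of the consuming side, written against the Rust service trait from this diff rather than the JNI layer AnkiDroid actually uses:

```rust
// Sketch of a caller driving the paging protocol end to end; error handling
// and the surrounding binding layer are elided.
fn fetch_all(backend: &Backend, command: Vec<u8>) -> Result<Vec<Row>> {
    let first = backend.run_db_command_proto(Json { json: command })?;
    let total = first.row_count as usize;
    let sequence = first.sequence_number;
    let mut rows = first.result.map(|r| r.rows).unwrap_or_default();
    while rows.len() < total {
        // `index` is the number of rows already consumed
        let page = backend.get_next_result_page(GetNextResultPageRequest {
            sequence,
            index: rows.len() as i64,
        })?;
        rows.extend(page.result.map(|r| r.rows).unwrap_or_default());
    }
    Ok(rows)
}
```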
git="https://github.com/ankitects/reqwest.git", rev="7591444614de02b "stream", "multipart", # the Bazel build scripts separate these out by platform - "native-tls", "rustls-tls", "rustls-tls-webpki-roots", "rustls-tls-native-roots", diff --git a/rslib/src/backend/ankidroid/db.rs b/rslib/src/backend/ankidroid/db.rs new file mode 100644 index 00000000000..57d0d10ca39 --- /dev/null +++ b/rslib/src/backend/ankidroid/db.rs @@ -0,0 +1,461 @@ +use std::{ + collections::HashMap, + mem::size_of, + sync::{ + atomic::{AtomicI32, Ordering}, + Mutex, + }, +}; + +// TODO: storing the results in a box in the cache is unnecessary and more fragile +use i64 as dbresponse_pointer; +use itertools::{ + FoldWhile, + FoldWhile::{Continue, Done}, + Itertools, +}; +use lazy_static::lazy_static; +use rusqlite::ToSql; +use serde_derive::Deserialize; + +use crate::{ + collection::Collection, + error::Result, + pb::{sql_value::Data, DbResponse, DbResult, Row, SqlValue}, +}; + +#[derive(Deserialize)] +struct DBArgs { + sql: String, + args: Vec, +} + +pub trait Sizable { + /** Estimates the heap size of the value, in bytes */ + fn estimate_size(&self) -> usize; +} + +impl Sizable for Data { + fn estimate_size(&self) -> usize { + match self { + Data::StringValue(s) => s.len(), + Data::LongValue(_) => size_of::(), + Data::DoubleValue(_) => size_of::(), + Data::BlobValue(b) => b.len(), + } + } +} + +impl Sizable for SqlValue { + fn estimate_size(&self) -> usize { + // Add a byte for the optional + self.data + .as_ref() + .map(|f| f.estimate_size() + 1) + .unwrap_or(1) + } +} + +impl Sizable for Row { + fn estimate_size(&self) -> usize { + self.fields.iter().map(|x| x.estimate_size()).sum() + } +} + +impl Sizable for DbResult { + fn estimate_size(&self) -> usize { + // Performance: It might be best to take the first x rows and determine the data types + // If we have floats or longs, they'll be a fixed size (excluding nulls) and should speed + // up the calculation as we'll only calculate a subset of the columns. + self.rows.iter().map(|x| x.estimate_size()).sum() + } +} + +pub(crate) fn select_next_slice<'a>(rows: impl Iterator) -> Vec { + select_slice_of_size(rows, get_max_page_size()) + .into_inner() + .1 +} + +fn select_slice_of_size<'a>( + mut rows: impl Iterator, + max_size: usize, +) -> FoldWhile<(usize, Vec)> { + let init: Vec = Vec::new(); + rows.fold_while((0, init), |mut acc, x| { + let new_size = acc.0 + x.estimate_size(); + // If the accumulator is 0, but we're over the size: return a single result so we don't loop forever. + // Theoretically, this shouldn't happen as data should be reasonably sized + if new_size > max_size && acc.0 > 0 { + Done(acc) + } else { + // PERF: should be faster to return (size, numElements) then bulk copy/slice + acc.1.push(x.to_owned()); + Continue((new_size, acc.1)) + } + }) +} + +lazy_static! 
+lazy_static! {
+    // i64 => Map<i32, dbresponse_pointer>
+    static ref HASHMAP: Mutex<HashMap<i64, HashMap<i32, dbresponse_pointer>>> = {
+        Mutex::new(HashMap::new())
+    };
+}
+
+pub(crate) fn flush_cache(ptr: i64, sequence_number: i32) {
+    let mut map = HASHMAP.lock().unwrap();
+    let entries = map.get_mut(&ptr);
+    match entries {
+        Some(seq_to_ptr) => {
+            let entry = seq_to_ptr.remove_entry(&sequence_number);
+            match entry {
+                Some(ptr) => {
+                    let raw = ptr.1 as *mut DbResponse;
+                    unsafe {
+                        Box::from_raw(raw);
+                    }
+                }
+                None => {}
+            }
+        }
+        None => {}
+    }
+}
+
+pub(crate) fn flush_all(ptr: i64) {
+    let mut map = HASHMAP.lock().unwrap();
+
+    // clear the map
+    let entries = map.remove_entry(&ptr);
+
+    match entries {
+        Some(seq_to_ptr_map) => {
+            // then free each value
+            for val in seq_to_ptr_map.1.values() {
+                let raw = (*val) as *mut DbResponse;
+                unsafe {
+                    Box::from_raw(raw);
+                }
+            }
+        }
+        None => {}
+    }
+}
+
+pub(crate) fn active_sequences(ptr: i64) -> Vec<i32> {
+    let mut map = HASHMAP.lock().unwrap();
+
+    match map.get_mut(&ptr) {
+        Some(x) => {
+            let keys = x.keys();
+            keys.into_iter().copied().collect_vec()
+        }
+        None => Vec::new(),
+    }
+}
+
+/** Store the data in the cache if larger than the page size.
+
+Returns: The data capped to the page size */
+pub(crate) fn trim_and_cache_remaining(
+    backend_ptr: i64,
+    values: DbResult,
+    sequence_number: i32,
+) -> DbResponse {
+    let start_index = 0;
+
+    // PERF: Could speed this up by not creating the vector and just calculating the count
+    let first_result = select_next_slice(values.rows.iter());
+
+    let row_count = values.rows.len() as i32;
+    if first_result.len() < values.rows.len() {
+        let to_store = DbResponse {
+            result: Some(values),
+            sequence_number,
+            row_count,
+            start_index,
+        };
+        insert_cache(backend_ptr, to_store);
+
+        DbResponse {
+            result: Some(DbResult { rows: first_result }),
+            sequence_number,
+            row_count,
+            start_index,
+        }
+    } else {
+        DbResponse {
+            result: Some(values),
+            sequence_number,
+            row_count,
+            start_index,
+        }
+    }
+}
+
+fn insert_cache(ptr: i64, result: DbResponse) {
+    let mut map = HASHMAP.lock().unwrap();
+
+    match map.get_mut(&ptr) {
+        Some(_) => {}
+        None => {
+            let map2: HashMap<i32, dbresponse_pointer> = HashMap::new();
+            map.insert(ptr, map2);
+        }
+    };
+
+    let out_hash_map = map.get_mut(&ptr).unwrap();
+
+    out_hash_map.insert(
+        result.sequence_number,
+        Box::into_raw(Box::new(result)) as dbresponse_pointer,
+    );
+}
+
+pub(crate) fn get_next(ptr: i64, sequence_number: i32, start_index: i64) -> Option<DbResponse> {
+    let result = get_next_result(ptr, &sequence_number, start_index);
+
+    match result.as_ref() {
+        Some(x) => {
+            if x.result.is_none() || x.result.as_ref().unwrap().rows.is_empty() {
+                flush_cache(ptr, sequence_number)
+            }
+        }
+        None => {}
+    }
+
+    result
+}
+
+fn get_next_result(ptr: i64, sequence_number: &i32, start_index: i64) -> Option<DbResponse> {
+    let map = HASHMAP.lock().unwrap();
+
+    let result_map = map.get(&ptr)?;
+
+    let backend_ptr = *result_map.get(sequence_number)?;
+
+    let current_result = unsafe { &mut *(backend_ptr as *mut DbResponse) };
+
+    // TODO: This shouldn't need to exist
+    let tmp: Vec<Row> = Vec::new();
+    let next_rows = current_result
+        .result
+        .as_ref()
+        .map(|x| x.rows.iter())
+        .unwrap_or_else(|| tmp.iter());
+
+    let filtered_rows = select_next_slice(next_rows.skip(start_index as usize));
+
+    let result = DbResult {
+        rows: filtered_rows,
+    };
+
+    let trimmed_result = DbResponse {
+        result: Some(result),
+        sequence_number: current_result.sequence_number,
+        row_count: current_result.row_count,
+        start_index,
+    };
+
+    Some(trimmed_result)
+}
+
+static SEQUENCE_NUMBER: AtomicI32 = AtomicI32::new(0);
+
+pub(crate) fn next_sequence_number() -> i32 {
+    SEQUENCE_NUMBER.fetch_add(1, Ordering::SeqCst)
+}
+
+lazy_static! {
+    // same as we get from io.requery.android.database.CursorWindow.sCursorWindowSize
+    static ref DB_COMMAND_PAGE_SIZE: Mutex<usize> = Mutex::new(1024 * 1024 * 2);
+}
+
+pub(crate) fn set_max_page_size(size: usize) {
+    let mut state = DB_COMMAND_PAGE_SIZE.lock().expect("Could not lock mutex");
+    *state = size;
+}
+
+fn get_max_page_size() -> usize {
+    *DB_COMMAND_PAGE_SIZE.lock().unwrap()
+}
+
+fn get_args(in_bytes: &[u8]) -> Result<DBArgs> {
+    let ret: DBArgs = serde_json::from_slice(in_bytes)?;
+    Ok(ret)
+}
+
+pub(crate) fn insert_for_id(col: &Collection, json: &[u8]) -> Result<i64> {
+    let req = get_args(json)?;
+    let args: Vec<_> = req.args.iter().map(|a| a as &dyn ToSql).collect();
+    col.storage.db.execute(&req.sql, &args[..])?;
+    Ok(col.storage.db.last_insert_rowid())
+}
+
+pub(crate) fn execute_for_row_count(col: &Collection, req: &[u8]) -> Result<i64> {
+    let req = get_args(req)?;
+    let args: Vec<_> = req.args.iter().map(|a| a as &dyn ToSql).collect();
+    let count = col.storage.db.execute(&req.sql, &args[..])?;
+    Ok(count as i64)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        backend::ankidroid::db::{select_slice_of_size, Sizable},
+        pb::{sql_value, Row, SqlValue},
+    };
+
+    fn gen_data() -> Vec<SqlValue> {
+        vec![
+            SqlValue {
+                data: Some(sql_value::Data::DoubleValue(12.0)),
+            },
+            SqlValue {
+                data: Some(sql_value::Data::LongValue(12)),
+            },
+            SqlValue {
+                data: Some(sql_value::Data::StringValue(
+                    "Hellooooooo World".to_string(),
+                )),
+            },
+            SqlValue {
+                data: Some(sql_value::Data::BlobValue(vec![])),
+            },
+        ]
+    }
+
+    #[test]
+    fn test_size_estimate() {
+        let row = Row { fields: gen_data() };
+        let result = DbResult {
+            rows: vec![row.clone(), row.clone()],
+        };
+
+        let actual_size = result.estimate_size();
+
+        let expected_size = (17 + 8 + 8) * 2; // 1 variable string, 1 long, 1 float
+        let expected_overhead = (4 * 1) * 2; // 4 optional columns
+
+        assert_eq!(actual_size, expected_overhead + expected_size);
+    }
+
+    #[test]
+    fn test_stream_size() {
+        let row = Row { fields: gen_data() };
+        let result = DbResult {
+            rows: vec![row.clone(), row.clone(), row.clone()],
+        };
+        let limit = 74 + 1; // two rows are 74
+
+        let result = select_slice_of_size(result.rows.iter(), limit).into_inner();
+
+        assert_eq!(
+            2,
+            result.1.len(),
+            "The final element should not be included"
+        );
+        assert_eq!(
+            74, result.0,
+            "The size should be the size of the first two objects"
+        );
+    }
+
+    #[test]
+    fn test_stream_size_too_small() {
+        let row = Row { fields: gen_data() };
+        let result = DbResult {
+            rows: vec![row.clone()],
+        };
+        let limit = 1;
+
+        let result = select_slice_of_size(result.rows.iter(), limit).into_inner();
+
+        assert_eq!(
+            1,
+            result.1.len(),
+            "If the limit is too small, a result is still returned"
+        );
+        assert_eq!(
+            37, result.0,
+            "The size should be the size of the first object"
+        );
+    }
+
+    const BACKEND_PTR: i64 = 12;
+    const SEQUENCE_NUMBER: i32 = 1;
+
+    fn get(index: i64) -> Option<DbResponse> {
+        get_next(BACKEND_PTR, SEQUENCE_NUMBER, index)
+    }
+
+    fn get_first(result: DbResult) -> DbResponse {
+        trim_and_cache_remaining(BACKEND_PTR, result, SEQUENCE_NUMBER)
+    }
+
+    fn seq_number_used() -> bool {
+        HASHMAP
+            .lock()
+            .unwrap()
+            .get(&BACKEND_PTR)
+            .unwrap()
+            .contains_key(&SEQUENCE_NUMBER)
+    }
+
+    #[test]
+    fn integration_test() {
+        let row = Row { fields: gen_data() };
+
+        // return one row at a time
+        set_max_page_size(row.estimate_size() - 1);
+
+        let db_query_result = DbResult {
+            rows: vec![row.clone(), row.clone()],
+        };
+
+        let first_jni_response = get_first(db_query_result);
+
+        assert_eq!(
+            row_count(&first_jni_response),
+            1,
+            "The first call should only return one row"
+        );
+
+        let next_index = first_jni_response.start_index + row_count(&first_jni_response);
+
+        let second_response = get(next_index);
+
+        assert!(
+            second_response.is_some(),
+            "The second response should return a value"
+        );
+        let valid_second_response = second_response.unwrap();
+        assert_eq!(row_count(&valid_second_response), 1);
+
+        let final_index = valid_second_response.start_index + row_count(&valid_second_response);
+
+        assert!(seq_number_used(), "The sequence number is assigned");
+
+        let final_response = get(final_index);
+        assert!(
+            final_response.is_some(),
+            "The third call should return something with no rows"
+        );
+        assert_eq!(
+            row_count(&final_response.unwrap()),
+            0,
+            "The third call should return something with no rows"
+        );
+        assert!(!seq_number_used(), "Sequence number data has been cleared");
+    }
+
+    fn row_count(resp: &DbResponse) -> i64 {
+        resp.result.as_ref().map(|x| x.rows.len()).unwrap_or(0) as i64
+    }
+}
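RunDbCommand, RunDbCommandProto, InsertForId and RunDbCommandForRowCount all accept a generic.Json payload that get_args() deserializes into DBArgs. A sketch of the wire format, assuming dbproxy's SqlValue uses serde's untagged representation so plain JSON scalars map onto the variants (the SQL string is a placeholder):

```rust
// Sketch: the JSON accepted by get_args()/insert_for_id(). Assumes dbproxy's
// SqlValue deserializes untagged, so "args" is a plain JSON array of scalars.
fn parse_db_args() -> crate::error::Result<()> {
    let payload = br#"{"sql": "update col set mod = ?", "args": [1659458000000]}"#;
    let req: DBArgs = serde_json::from_slice(payload)?;
    assert_eq!(req.sql, "update col set mod = ?");
    assert_eq!(req.args.len(), 1);
    Ok(())
}
```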
diff --git a/rslib/src/backend/ankidroid/error.rs b/rslib/src/backend/ankidroid/error.rs
new file mode 100644
index 00000000000..4114df65194
--- /dev/null
+++ b/rslib/src/backend/ankidroid/error.rs
@@ -0,0 +1,94 @@
+use crate::{
+    error::{
+        DbError, DbErrorKind as DB, FilteredDeckError, NetworkError, NetworkErrorKind as Net,
+        SearchErrorKind, SyncError, SyncErrorKind as Sync,
+    },
+    prelude::AnkiError,
+};
+
+pub(super) fn debug_produce_error(s: &str) -> AnkiError {
+    let error_value = "error_value".to_string();
+    match s {
+        "InvalidInput" => AnkiError::InvalidInput(error_value),
+        "TemplateError" => AnkiError::TemplateError(error_value),
+        "IoError" => AnkiError::IoError(error_value),
+        "DbErrorFileTooNew" => AnkiError::DbError(DbError {
+            info: error_value,
+            kind: DB::FileTooNew,
+        }),
+        "DbErrorFileTooOld" => AnkiError::DbError(DbError {
+            info: error_value,
+            kind: DB::FileTooOld,
+        }),
+        "DbErrorMissingEntity" => AnkiError::DbError(DbError {
+            info: error_value,
+            kind: DB::MissingEntity,
+        }),
+        "DbErrorCorrupt" => AnkiError::DbError(DbError {
+            info: error_value,
+            kind: DB::Corrupt,
+        }),
+        "DbErrorLocked" => AnkiError::DbError(DbError {
+            info: error_value,
+            kind: DB::Locked,
+        }),
+        "DbErrorOther" => AnkiError::DbError(DbError {
+            info: error_value,
+            kind: DB::Other,
+        }),
+        "NetworkError" => AnkiError::NetworkError(NetworkError {
+            info: error_value,
+            kind: Net::Offline,
+        }),
+        "SyncErrorConflict" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::Conflict,
+        }),
+        "SyncErrorServerError" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::ServerError,
+        }),
+        "SyncErrorClientTooOld" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::ClientTooOld,
+        }),
+        "SyncErrorAuthFailed" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::AuthFailed,
+        }),
+        "SyncErrorServerMessage" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::ServerMessage,
+        }),
+        "SyncErrorClockIncorrect" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::ClockIncorrect,
+        }),
+        "SyncErrorOther" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::Other,
+        }),
+        "SyncErrorResyncRequired" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::ResyncRequired,
+        }),
+        "SyncErrorDatabaseCheckRequired" => AnkiError::SyncError(SyncError {
+            info: error_value,
+            kind: Sync::DatabaseCheckRequired,
+        }),
+        "JSONError" => AnkiError::JsonError(error_value),
+        "ProtoError" => AnkiError::ProtoError(error_value),
+        "Interrupted" => AnkiError::Interrupted,
+        "CollectionNotOpen" => AnkiError::CollectionNotOpen,
+        "CollectionAlreadyOpen" => AnkiError::CollectionAlreadyOpen,
+        "NotFound" => AnkiError::NotFound,
+        "Existing" => AnkiError::Existing,
+        "FilteredDeckError" => {
+            AnkiError::FilteredDeckError(FilteredDeckError::FilteredDeckRequired)
+        }
+        "SearchError" => AnkiError::SearchError(SearchErrorKind::EmptyGroup),
+        "FatalError" => AnkiError::FatalError(error_value),
+        unknown => AnkiError::FatalError(format!("Unknown error code: {}", unknown)),
+    }
+}
diff --git a/rslib/src/backend/ankidroid/mod.rs b/rslib/src/backend/ankidroid/mod.rs
new file mode 100644
index 00000000000..b1a4680c21c
--- /dev/null
+++ b/rslib/src/backend/ankidroid/mod.rs
@@ -0,0 +1,122 @@
+// Copyright: Ankitects Pty Ltd and contributors
+// License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
+
+pub(crate) mod db;
+pub(crate) mod error;
+
+use self::{db::active_sequences, error::debug_produce_error};
+use super::{
+    dbproxy::{db_command_bytes, db_command_proto},
+    Backend,
+};
+pub(super) use crate::pb::ankidroid_service::Service as AnkidroidService;
+use crate::{
+    backend::ankidroid::db::{execute_for_row_count, insert_for_id},
+    pb::{
+        self as pb,
+        ankidroid::{DbResponse, GetActiveSequenceNumbersResponse, GetNextResultPageRequest},
+        generic::{self, Empty, Int32, Json},
+        Int64, StringList,
+    },
+    prelude::*,
+    scheduler::timing::{self, fixed_offset_from_minutes},
+};
+
+impl AnkidroidService for Backend {
+    fn sched_timing_today_legacy(
+        &self,
+        input: pb::SchedTimingTodayLegacyRequest,
+    ) -> Result<pb::SchedTimingTodayResponse> {
+        let result = timing::sched_timing_today(
+            TimestampSecs::from(input.created_secs),
+            TimestampSecs::from(input.now_secs),
+            input.created_mins_west.map(fixed_offset_from_minutes),
+            fixed_offset_from_minutes(input.now_mins_west),
+            Some(input.rollover_hour as u8),
+        );
+        Ok(pb::SchedTimingTodayResponse::from(result))
+    }
+
+    fn local_minutes_west_legacy(&self, input: pb::Int64) -> Result<pb::Int32> {
+        Ok(pb::Int32 {
+            val: timing::local_minutes_west_for_stamp(input.val),
+        })
+    }
+
+    fn run_db_command(&self, input: Json) -> Result<Json> {
+        self.with_col(|col| db_command_bytes(col, &input.json))
+            .map(|json| Json { json })
+    }
+
+    fn run_db_command_proto(&self, input: Json) -> Result<DbResponse> {
+        self.with_col(|col| db_command_proto(col, &input.json))
+    }
+
+    fn run_db_command_for_row_count(&self, input: Json) -> Result<Int64> {
+        self.with_col(|col| execute_for_row_count(col, &input.json))
+            .map(|val| Int64 { val })
+    }
+
+    fn flush_all_queries(&self, _input: Empty) -> Result<Empty> {
+        self.with_col(|col| {
+            db::flush_all(backend_id(col));
+            Ok(Empty {})
+        })
+    }
+
+    fn flush_query(&self, input: Int32) -> Result<Empty> {
+        self.with_col(|col| {
+            db::flush_cache(backend_id(col), input.val);
+            Ok(Empty {})
+        })
+    }
+
+    fn get_next_result_page(&self, input: GetNextResultPageRequest) -> Result<DbResponse> {
+        self.with_col(|col| {
+            let id = backend_id(col);
+            db::get_next(id, input.sequence, input.index).ok_or(AnkiError::NotFound)
+        })
+    }
+
+    fn insert_for_id(&self, input: Json) -> Result<Int64> {
+        self.with_col(|col| insert_for_id(col, &input.json).map(Into::into))
+    }
+
+    fn set_page_size(&self, input: Int64) -> Result<Empty> {
+        // we don't require an open collection, but should avoid modifying this
+        // concurrently
+        let _guard = self.col.lock();
+        db::set_max_page_size(input.val as usize);
+        Ok(().into())
+    }
+
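One property of backend_id (defined at the end of this file) that its doc comment only implies: the id is the address of the open collection's SqliteStorage, so it is stable for as long as the collection stays open, and a close/reopen cycle produces a fresh key; pages cached under the old key are only reclaimed via flush_all/flush_cache. A trivial illustration, as a hypothetical test rather than anything from the diff:

```rust
// Hypothetical: two simultaneously-open collections can never share a cache
// key, because their SqliteStorage values live at distinct addresses.
fn distinct_backends_get_distinct_keys(col_a: &Collection, col_b: &Collection) {
    assert_ne!(backend_id(col_a), backend_id(col_b));
    // and the key is stable across calls for the same open collection
    assert_eq!(backend_id(col_a), backend_id(col_a));
}
```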
+    fn get_column_names_from_query(&self, input: generic::String) -> Result<StringList> {
+        self.with_col(|col| {
+            let stmt = col.storage.db.prepare(&input.val)?;
+            let names = stmt.column_names();
+            let names: Vec<_> = names.iter().map(ToString::to_string).collect();
+            Ok(names.into())
+        })
+    }
+
+    fn get_active_sequence_numbers(
+        &self,
+        _input: Empty,
+    ) -> Result<GetActiveSequenceNumbersResponse> {
+        self.with_col(|col| {
+            Ok(GetActiveSequenceNumbersResponse {
+                numbers: active_sequences(backend_id(col)),
+            })
+        })
+    }
+
+    fn debug_produce_error(&self, input: generic::String) -> Result<Empty> {
+        Err(debug_produce_error(&input.val))
+    }
+}
+
+/// The old AnkiDroid code used the pointer to the backend as a cache index;
+/// now we use a pointer to SqliteStorage instead.
+pub(crate) fn backend_id(col: &Collection) -> i64 {
+    (&col.storage as *const _) as i64
+}
diff --git a/rslib/src/backend/collection.rs b/rslib/src/backend/collection.rs
index e1f12c350b3..be524d80bc9 100644
--- a/rslib/src/backend/collection.rs
+++ b/rslib/src/backend/collection.rs
@@ -31,6 +31,7 @@ impl CollectionService for Backend {
         let mut builder = CollectionBuilder::new(input.collection_path);
         builder
+            .set_force_schema11(input.force_schema11)
             .set_media_paths(input.media_folder_path, input.media_db_path)
             .set_server(self.server)
             .set_tr(self.tr.clone());
diff --git a/rslib/src/backend/dbproxy.rs b/rslib/src/backend/dbproxy.rs
index f43a0c51721..afbbbda3744 100644
--- a/rslib/src/backend/dbproxy.rs
+++ b/rslib/src/backend/dbproxy.rs
@@ -8,7 +8,13 @@ use rusqlite::{
 };
 use serde_derive::{Deserialize, Serialize};
 
-use crate::{prelude::*, storage::SqliteStorage};
+use super::ankidroid::backend_id;
+use crate::{
+    pb,
+    pb::{sql_value::Data, DbResponse, DbResult as ProtoDbResult, Row, SqlValue as pb_SqlValue},
+    prelude::*,
+    storage::SqliteStorage,
+};
 
 #[derive(Deserialize)]
 #[serde(tag = "kind", rename_all = "lowercase")]
@@ -57,6 +63,42 @@ impl ToSql for SqlValue {
     }
 }
 
+impl From<&SqlValue> for pb::SqlValue {
+    fn from(item: &SqlValue) -> Self {
+        match item {
+            SqlValue::Null => pb_SqlValue { data: Option::None },
+            SqlValue::String(s) => pb_SqlValue {
+                data: Some(Data::StringValue(s.to_string())),
+            },
+            SqlValue::Int(i) => pb_SqlValue {
+                data: Some(Data::LongValue(*i)),
+            },
+            SqlValue::Double(d) => pb_SqlValue {
+                data: Some(Data::DoubleValue(*d)),
+            },
+            SqlValue::Blob(b) => pb_SqlValue {
+                data: Some(Data::BlobValue(b.clone())),
+            },
+        }
+    }
+}
+
+impl From<&Vec<SqlValue>> for pb::Row {
+    fn from(item: &Vec<SqlValue>) -> Self {
+        Row {
+            fields: item.iter().map(pb::SqlValue::from).collect(),
+        }
+    }
+}
+
+impl From<&Vec<Vec<SqlValue>>> for pb::DbResult {
+    fn from(item: &Vec<Vec<SqlValue>>) -> Self {
+        ProtoDbResult {
+            rows: item.iter().map(Row::from).collect(),
+        }
+    }
+}
+
 impl FromSql for SqlValue {
     fn column_result(value: ValueRef<'_>) -> std::result::Result<Self, FromSqlError> {
         let val = match value {
@@ -71,6 +113,10 @@ impl FromSql for SqlValue {
 }
 
 pub(super) fn db_command_bytes(col: &mut Collection, input: &[u8]) -> Result<Vec<u8>> {
+    serde_json::to_vec(&db_command_bytes_inner(col, input)?).map_err(Into::into)
+}
+
+pub(super) fn db_command_bytes_inner(col: &mut Collection, input: &[u8]) -> Result<DbResult> {
     let req: DbRequest = serde_json::from_slice(input)?;
     let resp = match req {
         DbRequest::Query {
@@ -107,7 +153,7 @@ pub(super) fn db_command_bytes(col: &mut Collection, input: &[u8]) -> Result<Ve
     };
-    serde_json::to_vec(&resp).map_err(Into::into)
+    Ok(resp)
 }
 
 fn is_dql(sql: &str) -> bool {
     head.starts_with("select")
 }
 
+pub(crate) fn db_command_proto(col: &mut Collection, input: &[u8]) -> Result<DbResponse> {
+    let result = db_command_bytes_inner(col, input)?;
+    let proto_resp = match result {
+        DbResult::None => ProtoDbResult { rows: Vec::new() },
+        DbResult::Rows(rows) => ProtoDbResult::from(&rows),
+    };
+    let trimmed = super::ankidroid::db::trim_and_cache_remaining(
+        backend_id(col),
+        proto_resp,
+        super::ankidroid::db::next_sequence_number(),
+    );
+    Ok(trimmed)
+}
+
 pub(super) fn db_query_row(ctx: &SqliteStorage, sql: &str, args: &[SqlValue]) -> Result<DbResult> {
     let mut stmt = ctx.db.prepare_cached(sql)?;
     let columns = stmt.column_count();
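The three From impls above compose: a Vec<Vec<SqlValue>> produced by a query maps row by row into the protobuf DbResult, which db_command_proto then trims and caches. A small sketch with arbitrary values:

```rust
// Sketch: converting dbproxy's serde-level rows into the protobuf result type
// via the From impls above (ProtoDbResult is the pb::DbResult alias).
fn to_proto_example() {
    let rows: Vec<Vec<SqlValue>> = vec![
        vec![SqlValue::Int(1), SqlValue::String("front".into())],
        vec![SqlValue::Null, SqlValue::Double(0.5)],
    ];
    let proto: ProtoDbResult = (&rows).into();
    assert_eq!(proto.rows.len(), 2);
    assert_eq!(proto.rows[0].fields.len(), 2);
}
```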
diff --git a/rslib/src/backend/error.rs b/rslib/src/backend/error.rs
index a385a2ce862..b3e4d20aeb7 100644
--- a/rslib/src/backend/error.rs
+++ b/rslib/src/backend/error.rs
@@ -9,10 +9,11 @@ use crate::{
 };
 
 impl AnkiError {
-    pub(super) fn into_protobuf(self, tr: &I18n) -> pb::BackendError {
+    pub fn into_protobuf(self, tr: &I18n) -> pb::BackendError {
         let localized = self.localized_description(tr);
         let help_page = self.help_page().map(|page| page as i32);
         let kind = match self {
+            AnkiError::FatalError(_) => Kind::FatalError,
             AnkiError::InvalidInput(_) => Kind::InvalidInput,
             AnkiError::TemplateError(_) => Kind::TemplateParse,
             AnkiError::IoError(_) => Kind::IoError,
diff --git a/rslib/src/backend/mod.rs b/rslib/src/backend/mod.rs
index 818f8ffed54..c767594a8a1 100644
--- a/rslib/src/backend/mod.rs
+++ b/rslib/src/backend/mod.rs
@@ -5,6 +5,7 @@
 #![allow(clippy::unnecessary_wraps)]
 
 mod adding;
+mod ankidroid;
 mod card;
 mod cardrendering;
 mod collection;
@@ -41,6 +42,7 @@ use slog::Logger;
 use tokio::runtime::{self, Runtime};
 
 use self::{
+    ankidroid::AnkidroidService,
     card::CardsService,
     cardrendering::CardRenderingService,
     collection::CollectionService,
@@ -128,6 +130,7 @@ impl Backend {
         pb::ServiceIndex::from_i32(service as i32)
             .ok_or_else(|| AnkiError::invalid_input("invalid service"))
             .and_then(|service| match service {
+                pb::ServiceIndex::Ankidroid => AnkidroidService::run_method(self, method, input),
                 pb::ServiceIndex::Scheduler => SchedulerService::run_method(self, method, input),
                 pb::ServiceIndex::Decks => DecksService::run_method(self, method, input),
                 pb::ServiceIndex::Notes => NotesService::run_method(self, method, input),
diff --git a/rslib/src/collection/mod.rs b/rslib/src/collection/mod.rs
index e1222804f3c..55a50c38224 100644
--- a/rslib/src/collection/mod.rs
+++ b/rslib/src/collection/mod.rs
@@ -30,6 +30,7 @@ pub struct CollectionBuilder {
     server: Option<bool>,
     tr: Option<I18n>,
     log: Option<Logger>,
+    force_schema11: Option<bool>,
 }
 
 impl CollectionBuilder {
@@ -51,8 +52,8 @@ impl CollectionBuilder {
         let media_folder = self.media_folder.clone().unwrap_or_default();
         let media_db = self.media_db.clone().unwrap_or_default();
         let log = self.log.clone().unwrap_or_else(crate::log::terminal);
-
-        let storage = SqliteStorage::open_or_create(&col_path, &tr, server)?;
+        let force_schema11 = self.force_schema11.unwrap_or_default();
+        let storage = SqliteStorage::open_or_create(&col_path, &tr, server, force_schema11)?;
         let col = Collection {
             storage,
             col_path,
@@ -94,6 +95,11 @@ impl CollectionBuilder {
         self
     }
 
+    pub fn set_force_schema11(&mut self, force: bool) -> &mut Self {
+        self.force_schema11 = Some(force);
+        self
+    }
+
     /// Log to the provided file.
     pub fn set_log_file(&mut self, log_file: &str) -> Result<&mut Self, std::io::Error> {
         self.set_logger(default_logger(Some(log_file))?);
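For callers outside the backend service, the new builder flag reads like this; `build()` is my assumption for the existing CollectionBuilder construction method, and the paths are placeholders:

```rust
// Sketch of opening a schema-11 collection the way AnkiDroid would; only
// set_force_schema11 is new in this diff, the other setters already existed.
fn open_for_ankidroid() -> Result<Collection> {
    let mut builder = CollectionBuilder::new("/path/to/collection.anki2");
    builder
        .set_force_schema11(true)
        .set_media_paths("collection.media".into(), "collection.media.db2".into())
        .set_server(false);
    builder.build() // assumed: the builder's pre-existing entry point
}
```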
diff --git a/rslib/src/error/mod.rs b/rslib/src/error/mod.rs
index 32c6f9afa93..f9e3af03a0a 100644
--- a/rslib/src/error/mod.rs
+++ b/rslib/src/error/mod.rs
@@ -20,6 +20,7 @@ pub type Result<T, E = AnkiError> = std::result::Result<T, E>;
 
 #[derive(Debug, PartialEq)]
 pub enum AnkiError {
+    FatalError(String),
     InvalidInput(String),
     TemplateError(String),
     CardTypeError(CardTypeError),
@@ -119,6 +120,7 @@ impl AnkiError {
             AnkiError::FileIoError(err) => {
                 format!("{}: {}", err.path, err.error)
             }
+            AnkiError::FatalError(err) => err.to_owned(),
         }
     }
diff --git a/rslib/src/pb.rs b/rslib/src/pb.rs
index f3b6eff71f1..b782fd5ba45 100644
--- a/rslib/src/pb.rs
+++ b/rslib/src/pb.rs
@@ -13,6 +13,7 @@ macro_rules! protobuf {
     };
 }
 
+protobuf!(ankidroid);
 protobuf!(backend);
 protobuf!(card_rendering);
 protobuf!(cards);
diff --git a/rslib/src/storage/card/mod.rs b/rslib/src/storage/card/mod.rs
index 96b31448624..34d40e946a7 100644
--- a/rslib/src/storage/card/mod.rs
+++ b/rslib/src/storage/card/mod.rs
@@ -745,7 +745,8 @@ mod test {
     #[test]
     fn add_card() {
         let tr = I18n::template_only();
-        let storage = SqliteStorage::open_or_create(Path::new(":memory:"), &tr, false).unwrap();
+        let storage =
+            SqliteStorage::open_or_create(Path::new(":memory:"), &tr, false, false).unwrap();
         let mut card = Card::default();
         storage.add_card(&mut card).unwrap();
         let id1 = card.id;
diff --git a/rslib/src/storage/sqlite.rs b/rslib/src/storage/sqlite.rs
index 95a25dd92d7..ceb0d9a6522 100644
--- a/rslib/src/storage/sqlite.rs
+++ b/rslib/src/storage/sqlite.rs
@@ -205,7 +205,12 @@ fn trace(s: &str) {
 }
 
 impl SqliteStorage {
-    pub(crate) fn open_or_create(path: &Path, tr: &I18n, server: bool) -> Result<Self> {
+    pub(crate) fn open_or_create(
+        path: &Path,
+        tr: &I18n,
+        server: bool,
+        force_schema11: bool,
+    ) -> Result<Self> {
         let db = open_or_create_collection_db(path)?;
         let (create, ver) = schema_version(&db)?;
 
@@ -250,6 +255,13 @@ impl SqliteStorage {
 
         let storage = Self { db };
 
+        if force_schema11 {
+            if create || upgrade {
+                storage.commit_trx()?;
+            }
+            return storage_with_schema11(storage, ver);
+        }
+
         if create || upgrade {
             storage.upgrade_to_latest_schema(ver, server)?;
         }
@@ -370,3 +382,22 @@ impl SqliteStorage {
         self.db.query_row(sql, [], |r| r.get(0)).map_err(Into::into)
     }
 }
+
+fn storage_with_schema11(storage: SqliteStorage, ver: u8) -> Result<SqliteStorage> {
+    if ver != 11 {
+        if ver != SCHEMA_MAX_VERSION {
+            // partially upgraded; need to fully upgrade before downgrading
+            storage.begin_trx()?;
+            storage.upgrade_to_latest_schema(ver, false)?;
+            storage.commit_trx()?;
+        }
+        storage.downgrade_to(SchemaVersion::V11)?;
+    }
+    // Requery uses "TRUNCATE" by default if WAL is not enabled.
+    // We copy this behaviour here. See https://github.com/ankidroid/Anki-Android/pull/7977 for
+    // analysis. We may be able to enable WAL at a later time.
+    storage
+        .db
+        .pragma_update(None, "journal_mode", &"TRUNCATE")?;
+    Ok(storage)
+}
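Net effect of the two sqlite.rs changes: with force_schema11 set, open_or_create lands on schema 11 (fully upgrading first only when the file is partially upgraded) and pins journal_mode to TRUNCATE. A test-style sketch of the observable outcome; the temp-file path and reading `ver` from the schema-11 `col` table are assumptions, and an on-disk database is used because an in-memory one always reports journal_mode "memory":

```rust
// Sketch only: verifies the schema version and journal mode after a forced
// schema-11 open. Runs inside the crate, where `storage.db` is visible.
#[test]
fn force_schema11_open() -> Result<()> {
    let tr = I18n::template_only();
    let path = std::env::temp_dir().join("force_schema11.anki2");
    let storage = SqliteStorage::open_or_create(&path, &tr, false, true)?;

    // schema has been downgraded to (or kept at) 11
    let ver: i64 = storage.db.query_row("select ver from col", [], |r| r.get(0))?;
    assert_eq!(ver, 11);

    // journalling matches what Requery expects
    let mode: String = storage
        .db
        .query_row("PRAGMA journal_mode", [], |r| r.get(0))?;
    assert_eq!(mode, "truncate");
    Ok(())
}
```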