From 03a9d6b9b22e0421264c5c4675292136337eac47 Mon Sep 17 00:00:00 2001 From: David Allison <62114487+david-allison-1@users.noreply.github.com> Date: Thu, 1 Apr 2021 04:59:45 +0100 Subject: [PATCH] feat: fix OOM via memory limit on DB row results We now use a 2MB page size, the same as CursorWindow.sCursorWindowSize We occasionally got OOMs on methods which returned unbounded data eg. getting the field data for notes: "Check Database" crashed To fix this, instead of saying to the Rust "we want 1000 rows max", we say "we want max 2MB of data" The calculation for the data is inexact - string length, and 8 bytes for doubles/longs. Hasn't been tested thoroughly, but seems to only be < ~7% off the protobuf size for a string-only column (underestimate). We measure the rust in-memory usage rather than the size of the serialized protobuf Measuring the serialized size wasn't researched, but was assumed to be hard, as we would need to stream into a protobuf collection, and be able to efficiently query the new size for each row we add. 
Main changes: StreamingProtobufSQLiteCursor - no longer use pages dbcommand: allow access via an offset to the result set rather than via pages Adds: setDbPageSize to allow the change of the size for debugging Adds: Unit tests for the rust - not yet executed in CI (#51) rename: getPage -> getNextSlice bumps `anki` commit to add field sqlite.proto#DbResponse:start_index Fixes #14 (no longer necessary) Fixes ankidroid/#8178 --- anki | 2 +- .../rsdroid/DatabaseIntegrationTests.java | 13 +- .../rsdroid/ankiutil/InstrumentedTest.java | 11 + .../StreamingProtobufSQLiteCursorTest.java | 74 +++++- .../net/ankiweb/rsdroid/BackendMutex.java | 4 +- .../net/ankiweb/rsdroid/BackendV1Impl.java | 13 +- .../net/ankiweb/rsdroid/NativeMethods.java | 8 +- .../ankiweb/rsdroid/database/SQLHandler.java | 2 +- .../net/ankiweb/rsdroid/database/Session.java | 4 +- .../StreamingProtobufSQLiteCursor.java | 30 ++- rslib-bridge/src/dbcommand.rs | 250 +++++++++++++++++- rslib-bridge/src/lib.rs | 42 +-- 12 files changed, 384 insertions(+), 69 deletions(-) diff --git a/anki b/anki index c597f2630..c9e12052e 160000 --- a/anki +++ b/anki @@ -1 +1 @@ -Subproject commit c597f2630d397ab441e81ee15895cf1df46fafb8 +Subproject commit c9e12052e2c3d6c146f06461abef6f60ef8cca58 diff --git a/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/DatabaseIntegrationTests.java b/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/DatabaseIntegrationTests.java index 56f53d46f..91f7ec6e7 100644 --- a/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/DatabaseIntegrationTests.java +++ b/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/DatabaseIntegrationTests.java @@ -23,7 +23,6 @@ import androidx.sqlite.db.SupportSQLiteDatabase; import net.ankiweb.rsdroid.ankiutil.DatabaseUtil; -import net.ankiweb.rsdroid.database.StreamingProtobufSQLiteCursor; import net.ankiweb.rsdroid.database.testutils.DatabaseComparison; import org.junit.Test; @@ -38,6 +37,10 @@ 
@RunWith(Parameterized.class) public class DatabaseIntegrationTests extends DatabaseComparison { + private static final int INT_SIZE_BYTES = 8; + private static final int OPTIONAL_BYTES = 1; + /** Number of integers in 1 page of DB results when under test (111) */ + public static int DB_PAGE_NUM_INT_ELEMENTS = TEST_PAGE_SIZE / (INT_SIZE_BYTES + OPTIONAL_BYTES); @Test public void testScalar() { @@ -254,12 +257,12 @@ public void testRowCountPage() { db.execSQL("create table test (id int)"); - for (int i = 0; i < StreamingProtobufSQLiteCursor.RUST_PAGE_SIZE; i++) { + for (int i = 0; i < DB_PAGE_NUM_INT_ELEMENTS; i++) { db.execSQL("insert into test VALUES (1)"); } try (Cursor c = db.query("select * from test")) { - assertThat(c.getCount(), is(StreamingProtobufSQLiteCursor.RUST_PAGE_SIZE)); + assertThat(c.getCount(), is(DB_PAGE_NUM_INT_ELEMENTS)); } } @@ -269,12 +272,12 @@ public void testRowCountPageAndOne() { db.execSQL("create table test (id int)"); - for (int i = 0; i < StreamingProtobufSQLiteCursor.RUST_PAGE_SIZE + 1; i++) { + for (int i = 0; i < DB_PAGE_NUM_INT_ELEMENTS + 1; i++) { db.execSQL("insert into test VALUES (1)"); } try (Cursor c = db.query("select * from test")) { - assertThat(c.getCount(), is(StreamingProtobufSQLiteCursor.RUST_PAGE_SIZE + 1)); + assertThat(c.getCount(), is(DB_PAGE_NUM_INT_ELEMENTS + 1)); } } diff --git a/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/ankiutil/InstrumentedTest.java b/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/ankiutil/InstrumentedTest.java index 2ae39d68d..088321a3f 100644 --- a/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/ankiutil/InstrumentedTest.java +++ b/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/ankiutil/InstrumentedTest.java @@ -25,6 +25,8 @@ import net.ankiweb.rsdroid.BackendFactory; import net.ankiweb.rsdroid.BackendUtils; import net.ankiweb.rsdroid.BackendV1; +import net.ankiweb.rsdroid.BackendV1Impl; +import 
net.ankiweb.rsdroid.NativeMethods; import net.ankiweb.rsdroid.RustBackendFailedException; import org.junit.After; @@ -44,6 +46,8 @@ public class InstrumentedTest { private final List backendList = new ArrayList<>(); + protected final static int TEST_PAGE_SIZE = 1000; + @Before public void before() { /* @@ -51,6 +55,13 @@ public void before() { Timber.uprootAll(); Timber.plant(new Timber.DebugTree()); */ + + try { + NativeMethods.ensureSetup(); + } catch (RustBackendFailedException e) { + throw new RuntimeException(e); + } + BackendV1Impl.setPageSize(TEST_PAGE_SIZE); } @After diff --git a/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursorTest.java b/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursorTest.java index 9ff85973d..49f63f276 100644 --- a/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursorTest.java +++ b/rsdroid-instrumented/src/androidTest/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursorTest.java @@ -21,19 +21,19 @@ import androidx.sqlite.db.SupportSQLiteDatabase; import net.ankiweb.rsdroid.BackendV1; +import net.ankiweb.rsdroid.DatabaseIntegrationTests; import net.ankiweb.rsdroid.ankiutil.InstrumentedTest; -import org.junit.Ignore; import org.junit.Test; import java.io.IOException; +import java.util.HashSet; +import java.util.Set; import timber.log.Timber; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -import static org.junit.Assert.fail; public class StreamingProtobufSQLiteCursorTest extends InstrumentedTest { @@ -75,7 +75,7 @@ private void iterateAllRows(SupportSQLiteDatabase db) { @Test public void testCorruptionIsHandled() throws IOException { - int elements = StreamingProtobufSQLiteCursor.RUST_PAGE_SIZE; + int elements = DatabaseIntegrationTests.DB_PAGE_NUM_INT_ELEMENTS; try (BackendV1 
backend = super.getBackend("initial_version_2_12_1.anki2")) { SupportSQLiteDatabase db = new RustSupportSQLiteOpenHelper(backend).getWritableDatabase(); @@ -109,4 +109,70 @@ public void testCorruptionIsHandled() throws IOException { } } + + @Test + public void smallQueryHasOneCount() throws IOException { + int elements = 30; // 465 + + + try (BackendV1 backend = super.getBackend("initial_version_2_12_1.anki2")) { + SupportSQLiteDatabase db = new RustSupportSQLiteOpenHelper(backend).getWritableDatabase(); + + db.execSQL("create table tmp (id varchar)"); + for (int i = 0; i < elements + 1; i++) { + String inputOfLength = new String(new char[elements]).replace("\0", "a"); + db.execSQL("insert into tmp (id) values (?)", new Object[] {inputOfLength}); + } + + try (TestCursor c1 = new TestCursor(backend, "select * from tmp", new Object[] { })) { + + Set sizes = new HashSet<>(); + + while (c1.moveToNext()) { + if (sizes.add(c1.getSliceSize()) && sizes.size() > 1) { + throw new IllegalStateException("Expected single size of results"); + } + } + } + } + } + + @Test + public void variableLengthStringsReturnDifferentRowCounts() throws IOException { + int elements = 50; // 1275 > 1000 + + try (BackendV1 backend = super.getBackend("initial_version_2_12_1.anki2")) { + SupportSQLiteDatabase db = new RustSupportSQLiteOpenHelper(backend).getWritableDatabase(); + + db.execSQL("create table tmp (id varchar)"); + for (int i = 0; i < elements + 1; i++) { + String inputOfLength = new String(new char[elements]).replace("\0", "a"); + db.execSQL("insert into tmp (id) values (?)", new Object[] {inputOfLength}); + } + + try (TestCursor c1 = new TestCursor(backend, "select * from tmp", new Object[] { })) { + + Set sizes = new HashSet<>(); + + while (c1.moveToNext()) { + if (sizes.add(c1.getSliceSize()) && sizes.size() > 1) { + return; + } + } + + throw new IllegalStateException("Expected multiple sizes of results"); + } + } + } + + private static class TestCursor extends 
StreamingProtobufSQLiteCursor { + + public TestCursor(SQLHandler backend, String query, Object[] bindArgs) { + super(backend, query, bindArgs); + } + + public int getSliceSize() { + return getCurrentSliceRowCount(); + } + } } diff --git a/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendMutex.java b/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendMutex.java index 31337bf2a..6f48c9701 100644 --- a/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendMutex.java +++ b/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendMutex.java @@ -139,10 +139,10 @@ public String getPath() { } @Override - public Sqlite.DBResponse getPage(int page, int sequenceNumber) { + public Sqlite.DBResponse getNextSlice(long startIndex, int sequenceNumber) { try { lock.lock(); - return backend.getPage(page, sequenceNumber); + return backend.getNextSlice(startIndex, sequenceNumber); } finally { lock.unlock(); } diff --git a/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendV1Impl.java b/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendV1Impl.java index c25614b57..714f98a11 100644 --- a/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendV1Impl.java +++ b/rsdroid/src/main/java/net/ankiweb/rsdroid/BackendV1Impl.java @@ -18,6 +18,7 @@ import androidx.annotation.CheckResult; import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; import com.google.protobuf.InvalidProtocolBufferException; @@ -270,13 +271,13 @@ public Sqlite.DBResponse fullQueryProto(String query, Object... 
args) { } @Override - public Sqlite.DBResponse getPage(int page, int sequenceNumber) { + public Sqlite.DBResponse getNextSlice(long startIndex, int sequenceNumber) { byte[] result = null; try { - Timber.d("Rust: getPage %d", page); + Timber.d("Rust: getNextSlice %d", startIndex); Pointer backend = ensureBackend(); - result = NativeMethods.databaseGetNextResultPage(backend.toJni(), sequenceNumber, page); + result = NativeMethods.databaseGetNextResultPage(backend.toJni(), sequenceNumber, startIndex); Sqlite.DBResponse message = Sqlite.DBResponse.parseFrom(result); validateMessage(result, message); @@ -340,6 +341,12 @@ private void performTransaction(String kind) { } } + @VisibleForTesting(otherwise = VisibleForTesting.NONE) + public static void setPageSize(long pageSizeInBytes) { + // TODO: Make this nonstatic + NativeMethods.setDbPageSize(pageSizeInBytes); + } + @Override public String[] getColumnNames(String sql) { diff --git a/rsdroid/src/main/java/net/ankiweb/rsdroid/NativeMethods.java b/rsdroid/src/main/java/net/ankiweb/rsdroid/NativeMethods.java index a4cd90f4b..e4319c1d1 100644 --- a/rsdroid/src/main/java/net/ankiweb/rsdroid/NativeMethods.java +++ b/rsdroid/src/main/java/net/ankiweb/rsdroid/NativeMethods.java @@ -72,7 +72,7 @@ static void execCommand(long backendPointer, final int command, byte[] args) { /** Returns the next page of results after a databaseCommand. * @return DbResult object */ @CheckResult - static native byte[] databaseGetNextResultPage(long backendPointer, int sequenceNumber, int page); + static native byte[] databaseGetNextResultPage(long backendPointer, int sequenceNumber, long startIndex); /** Clears the memory from the current protobuf query. 
*/ static native int cancelCurrentProtoQuery(long backendPointer, int sequenceNumber); @@ -98,6 +98,12 @@ static void execCommand(long backendPointer, final int command, byte[] args) { static native byte[] executeAnkiDroidCommand(long backendPointer, int command, byte[] args); + /** + * Sets the maximum number of bytes that a page of database results should return + * {@link net.ankiweb.rsdroid.database.StreamingProtobufSQLiteCursor} + */ + static native void setDbPageSize(long numberOfBytes); + /** * Produces all possible Rust-based errors. */ diff --git a/rsdroid/src/main/java/net/ankiweb/rsdroid/database/SQLHandler.java b/rsdroid/src/main/java/net/ankiweb/rsdroid/database/SQLHandler.java index 8ce9632e5..3b7263591 100644 --- a/rsdroid/src/main/java/net/ankiweb/rsdroid/database/SQLHandler.java +++ b/rsdroid/src/main/java/net/ankiweb/rsdroid/database/SQLHandler.java @@ -41,7 +41,7 @@ public interface SQLHandler { String getPath(); /* Protobuf-related (#6) */ - Sqlite.DBResponse getPage(int page, int sequenceNumber); + Sqlite.DBResponse getNextSlice(long startIndex, int sequenceNumber); Sqlite.DBResponse fullQueryProto(String query, Object... 
bindArgs); void cancelCurrentProtoQuery(int sequenceNumber); diff --git a/rsdroid/src/main/java/net/ankiweb/rsdroid/database/Session.java b/rsdroid/src/main/java/net/ankiweb/rsdroid/database/Session.java index 972eac25b..4ff017ce0 100644 --- a/rsdroid/src/main/java/net/ankiweb/rsdroid/database/Session.java +++ b/rsdroid/src/main/java/net/ankiweb/rsdroid/database/Session.java @@ -86,8 +86,8 @@ public String getPath() { } @Override - public Sqlite.DBResponse getPage(int page, int sequenceNumber) { - return backend.getPage(page, sequenceNumber); + public Sqlite.DBResponse getNextSlice(long startIndex, int sequenceNumber) { + return backend.getNextSlice(startIndex, sequenceNumber); } @Override diff --git a/rsdroid/src/main/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursor.java b/rsdroid/src/main/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursor.java index 6e27f5a76..53bd4d364 100644 --- a/rsdroid/src/main/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursor.java +++ b/rsdroid/src/main/java/net/ankiweb/rsdroid/database/StreamingProtobufSQLiteCursor.java @@ -27,28 +27,35 @@ import BackendProto.Sqlite; public class StreamingProtobufSQLiteCursor extends AnkiDatabaseCursor { - // Interleaved cursors would corrupt data if there are more than PAGE_SIZE results. 
- // We currently use mSequenceNumber to crash if this is the case - - // MAINTENANCE: This is not obtained from the Rust, so must manually be kept in sync - public static final int RUST_PAGE_SIZE = 1000; + /** + * Rust Implementation: + * + * When we request a query, rust calculates 2MB (default) of results and sends it to us + * + * We keep track of where we are with getSliceStartIndex: the index into the rust collection + * + * The next request should be for index: getSliceStartIndex() + getCurrentSliceRowCount() + */ private final SQLHandler backend; private final String query; private Sqlite.DBResponse results; private int position = -1; - private int page = -1; private String[] columnMapping; private boolean isClosed = false; private final int sequenceNumber; + /** The total number of rows for the query */ private final int rowCount; + /** The current index into the collection of rows */ + private long getSliceStartIndex() { + return results.getStartIndex(); + } public StreamingProtobufSQLiteCursor(SQLHandler backend, String query, Object[] bindArgs) { this.backend = backend; this.query = query; - page++; try { results = this.backend.fullQueryProto(this.query, bindArgs); sequenceNumber = results.getSequenceNumber(); @@ -59,11 +66,10 @@ public StreamingProtobufSQLiteCursor(SQLHandler backend, String query, Object[] } private void getNextPage() { - page++; position = -1; try { - results = backend.getPage(page, sequenceNumber); + results = backend.getNextSlice(getSliceStartIndex() + getCurrentSliceRowCount(), sequenceNumber); if (results.getSequenceNumber() != sequenceNumber) { throw new IllegalStateException("rsdroid does not currently handle nested cursor-based queries. 
Please change the code to avoid holding a reference to the query, or implement the functionality in rsdroid"); } @@ -79,7 +85,7 @@ public int getCount() { @Override public int getPosition() { - return position; + return (int) getSliceStartIndex() + position; } @Override @@ -93,7 +99,7 @@ public boolean moveToFirst() { @Override public boolean moveToNext() { - if (getCurrentSliceRowCount() > 0 && position + 1 >= RUST_PAGE_SIZE && getCount() != RUST_PAGE_SIZE) { + if (getCurrentSliceRowCount() > 0 && position + 1 >= getCurrentSliceRowCount() && getCount() != getCurrentSliceRowCount()) { getNextPage(); } position++; @@ -250,7 +256,7 @@ private Sqlite.SqlValue getFieldAtIndex(int columnIndex) { return getRowAtCurrentPosition().getFields(columnIndex); } - private int getCurrentSliceRowCount() { + protected int getCurrentSliceRowCount() { return results.getResult().getRowsCount(); } } diff --git a/rslib-bridge/src/dbcommand.rs b/rslib-bridge/src/dbcommand.rs index 87eb8fd46..943da044c 100644 --- a/rslib-bridge/src/dbcommand.rs +++ b/rslib-bridge/src/dbcommand.rs @@ -12,7 +12,72 @@ use anki::backend_proto::{DbResponse, DbResult}; use i64 as backend_pointer; use i64 as dbresponse_pointer; -use itertools::Itertools; +use anki::backend_proto::{Row, SqlValue}; +use std::mem::size_of; +use anki::backend_proto::sql_value::Data; +use itertools::{Itertools, FoldWhile}; +use itertools::FoldWhile::{Done, Continue}; +use std::ops::Deref; + + +pub trait Sizable { + /** Estimates the heap size of the value, in bytes */ + fn estimate_size(&self) -> usize; +} + +impl Sizable for Data { + fn estimate_size(&self) -> usize { + match self { + Data::StringValue(s) => { s.len() } + Data::LongValue(_) => { size_of::() } + Data::DoubleValue(_) => { size_of::() } + Data::BlobValue(b) => { b.len() } + } + } +} + +impl Sizable for SqlValue { + fn estimate_size(&self) -> usize { + // Add a byte for the optional + self.data.as_ref().map(|f| f.estimate_size() + 1).unwrap_or(1) + } +} + +impl Sizable 
for Row { + fn estimate_size(&self) -> usize { + self.fields.iter().map(|x| x.estimate_size()).sum() + } +} + +impl Sizable for DbResult { + fn estimate_size(&self) -> usize { + // Performance: It might be best to take the first x rows and determine the data types + // If we have floats or longs, they'll be a fixed size (excluding nulls) and should speed + // up the calculation as we'll only calculate a subset of the columns. + self.rows.iter().map(|x| x.estimate_size()).sum() + } +} + +pub(crate) fn select_next_slice<'a>(mut rows : impl Iterator) -> Vec { + select_slice_of_size(rows, get_max_page_size()).into_inner().1 +} + +fn select_slice_of_size<'a>(mut rows : impl Iterator, max_size: usize) -> FoldWhile<(usize, Vec)> { + let init: Vec = Vec::new(); + let folded = rows.fold_while((0, init), |mut acc, x| { + let new_size = acc.0 + x.estimate_size(); + // If the accumulator is 0, but we're over the size: return a single result so we don't loop forever. + // Theoretically, this shouldn't happen as data should be reasonably sized + if new_size > max_size && acc.0 > 0 { + Done(acc) + } else { + // PERF: should be faster to return (size, numElements) then bulk copy/slice + acc.1.push(x.to_owned()); + Continue((new_size, acc.1)) + } + }); + folded +} lazy_static! { // backend_pointer => Map @@ -72,7 +137,28 @@ pub(crate) fn active_sequences(ptr : backend_pointer) -> Vec { } } -pub(crate) fn insert_cache(ptr : backend_pointer, result : DbResponse) { +/** +Store the data in the cache if larger than the page size.
+Returns: The data capped to the page size +*/ +pub(crate) fn trim_and_cache_remaining(backend_ptr: i64, values: DbResult, sequence_number: i32) -> DbResponse { + let start_index = 0; + + // PERF: Could speed this up by not creating the vector and just calculating the count + let first_result = select_next_slice(values.rows.iter()); + + let row_count = values.rows.len() as i32; + if first_result.len() < values.rows.len() { + let to_store = DbResponse { result: Some(values), sequence_number, row_count, start_index }; + insert_cache(backend_ptr, to_store); + + DbResponse { result: Some(DbResult { rows: first_result }), sequence_number, row_count, start_index } + } else { + DbResponse { result: Some(values), sequence_number, row_count, start_index } + } +} + +fn insert_cache(ptr : backend_pointer, result : DbResponse) { let mut map = HASHMAP.lock().unwrap(); match map.get_mut(&ptr) { @@ -88,7 +174,22 @@ pub(crate) fn insert_cache(ptr : backend_pointer, result : DbResponse) { out_hash_map.insert(result.sequence_number, Box::into_raw(Box::new(result)) as dbresponse_pointer); } -pub(crate) unsafe fn get_next(ptr : backend_pointer, sequence_number : i32, offset : usize, to_take : usize) -> Option { +pub(crate) unsafe fn get_next(ptr : backend_pointer, sequence_number : i32, start_index : i64) -> Option { + let result = get_next_result(ptr, &sequence_number, start_index); + + match result.as_ref() { + Some(x) => { + if x.result.is_none() || x.result.as_ref().unwrap().rows.is_empty() { + flush_cache(&ptr, sequence_number) + } + }, + None => {} + } + + result +} + +unsafe fn get_next_result(ptr: backend_pointer, sequence_number: &i32, start_index: i64) -> Option { let map = HASHMAP.lock().unwrap(); let result_map = map.get(&ptr)?; @@ -97,13 +198,18 @@ pub(crate) unsafe fn get_next(ptr : backend_pointer, sequence_number : i32, offs let current_result = &mut *(backend_ptr as *mut DbResponse); - let result = DbResult { rows: current_result.result.as_ref().unwrap_or(&DbResult { 
rows: Vec::new()} ).rows.iter().skip(offset).take(to_take).cloned().collect() }; + // TODO: This shouldn't need to exist + let tmp: Vec = Vec::new(); + let next_rows = current_result.result.as_ref().map(|x| x.rows.iter()).unwrap_or(tmp.iter()); - if result.rows.is_empty() { - flush_cache(&ptr, sequence_number) - } + let skipped_rows = next_rows.clone().skip(start_index as usize).collect_vec(); + println!("{}", skipped_rows.len()); + + let filtered_rows = select_next_slice(next_rows.skip(start_index as usize)); + + let result = DbResult { rows: filtered_rows }; - let trimmed_result = DbResponse { result: Some(result), sequence_number: current_result.sequence_number, row_count: current_result.row_count }; + let trimmed_result = DbResponse { result: Some(result), sequence_number: current_result.sequence_number, row_count: current_result.row_count, start_index }; Some(trimmed_result) } @@ -113,4 +219,132 @@ static mut SEQUENCE_NUMBER: i32 = 0; pub(crate) unsafe fn next_sequence_number() -> i32 { SEQUENCE_NUMBER = SEQUENCE_NUMBER + 1; SEQUENCE_NUMBER +} + +lazy_static!{ + // same as we get from io.requery.android.database.CursorWindow.sCursorWindowSize + static ref DB_COMMAND_PAGE_SIZE: Mutex = Mutex::new(1024 * 1024 * 2); +} + +pub(crate) fn set_max_page_size(size: usize) { + let mut state = DB_COMMAND_PAGE_SIZE.lock().expect("Could not lock mutex"); + *state = size; +} + +fn get_max_page_size() -> usize { + *DB_COMMAND_PAGE_SIZE.lock().unwrap() +} + + +#[cfg(test)] +mod tests { + use super::*; + + use anki::backend_proto::{sql_value, Row, SqlValue}; + use crate::dbcommand::{Sizable, select_slice_of_size}; + use std::borrow::Borrow; + + fn gen_data() -> Vec { + vec![ + SqlValue{ + data: Some(sql_value::Data::DoubleValue(12.0)) + }, + SqlValue{ + data: Some(sql_value::Data::LongValue(12)) + }, + SqlValue{ + data: Some(sql_value::Data::StringValue("Hellooooooo World".to_string())) + }, + SqlValue{ + data: Some(sql_value::Data::BlobValue(vec![])) + } + ] + } + + #[test] + 
fn test_size_estimate() { + let row = Row { fields: gen_data() }; + let result = DbResult { rows: vec![row.clone(), row.clone()] }; + + let actual_size = result.estimate_size(); + + let expected_size = (17 + 8 + 8) * 2; // 1 variable string, 1 long, 1 float + let expected_overhead = (4 * 1) * 2; // 4 optional columns + + assert_eq!(actual_size, expected_overhead + expected_size); + } + + #[test] + fn test_stream_size() { + let row = Row { fields: gen_data() }; + let result = DbResult { rows: vec![row.clone(), row.clone(), row.clone()] }; + let limit = 74 + 1; // two rows are 74 + + let result = select_slice_of_size(result.rows.iter(), limit).into_inner(); + + assert_eq!(2, result.1.len(), "The final element should not be included"); + assert_eq!(74, result.0, "The size should be the size of the first two objects"); + } + + #[test] + fn test_stream_size_too_small() { + let row = Row { fields: gen_data() }; + let result = DbResult { rows: vec![row.clone()] }; + let limit = 1; + + let result = select_slice_of_size(result.rows.iter(), limit).into_inner(); + + assert_eq!(1, result.1.len(), "If the limit is too small, a result is still returned"); + assert_eq!(37, result.0, "The size should be the size of the first objects"); + } + + const BACKEND_PTR: i64 = 12; + const SEQUENCE_NUMBER: i32 = 1; + + fn get(index : i64) -> Option { + unsafe { return get_next(BACKEND_PTR, SEQUENCE_NUMBER, index) }; + } + + fn get_first(result : DbResult) -> DbResponse { + trim_and_cache_remaining(BACKEND_PTR, result, SEQUENCE_NUMBER) + } + + fn seq_number_used() -> bool { + HASHMAP.lock().unwrap().get(&BACKEND_PTR).unwrap().contains_key(&SEQUENCE_NUMBER) + } + + #[test] + fn integration_test() { + let row = Row { fields: gen_data() }; + + // return one row at a time + set_max_page_size(row.estimate_size() - 1); + + let db_query_result = DbResult { rows: vec![row.clone(), row.clone()] }; + + let first_jni_response = get_first(db_query_result); + + assert_eq!(row_count(&first_jni_response), 
1, "The first call should only return one row"); + + let next_index = first_jni_response.start_index + row_count(&first_jni_response); + + let second_response = get(next_index); + + assert!(second_response.is_some(), "The second response should return a value"); + let valid_second_response = second_response.unwrap(); + assert_eq!(row_count(&valid_second_response), 1); + + let final_index = valid_second_response.start_index + row_count(&valid_second_response); + + assert!(seq_number_used(), "The sequence number is assigned"); + + let final_response = get(final_index); + assert!(final_response.is_some(), "The third call should return something with no rows"); + assert_eq!(row_count(&final_response.unwrap()), 0, "The third call should return something with no rows"); + assert!(!seq_number_used(), "Sequence number data has been cleared"); + } + + fn row_count(resp: &DbResponse) -> i64 { + resp.result.as_ref().map(|x| x.rows.len()).unwrap_or(0) as i64 + } } \ No newline at end of file diff --git a/rslib-bridge/src/lib.rs b/rslib-bridge/src/lib.rs index 569af4758..059a21b6a 100644 --- a/rslib-bridge/src/lib.rs +++ b/rslib-bridge/src/lib.rs @@ -33,13 +33,6 @@ mod backend_proto; // TODO: Use a macro to handle panics to reduce code duplication -// FUTURE_EXTENSION: Allow DB_COMMAND_NUM_ROWS to be variable to allow tuning of memory usage -// Maybe also change this to a per-MB value if it's easy to stream-serialise to protobuf until a -// memory limit is hit. 
- -// MAINTENANCE: This must manually be kept in sync with the Java -const DB_COMMAND_NUM_ROWS: usize = 1000; - impl From for SchedTimingTodayOut2 { fn from(data: SchedTimingToday) -> Self { SchedTimingTodayOut2 { @@ -307,7 +300,7 @@ pub unsafe extern "C" fn Java_net_ankiweb_rsdroid_NativeMethods_databaseGetNextR _: JClass, backend_ptr : jlong, sequence_number: jint, - page: jint + requested_index: jlong ) -> jbyteArray { let backend = to_backend(backend_ptr); @@ -317,8 +310,7 @@ pub unsafe extern "C" fn Java_net_ankiweb_rsdroid_NativeMethods_databaseGetNextR let next_page = dbcommand::get_next( backend_ptr, sequence_number, - (page as usize) * DB_COMMAND_NUM_ROWS as usize, - DB_COMMAND_NUM_ROWS + requested_index ).unwrap(); @@ -369,7 +361,7 @@ pub unsafe extern "C" fn Java_net_ankiweb_rsdroid_NativeMethods_databaseCommand( match out_res { Ok(db_result) => { - let trimmed = trim_and_cache_remaining(backend_ptr, db_result, dbcommand::next_sequence_number()); + let trimmed = dbcommand::trim_and_cache_remaining(backend_ptr, db_result, dbcommand::next_sequence_number()); let mut out_bytes = Vec::new(); trimmed.encode(&mut out_bytes).unwrap(); @@ -499,6 +491,14 @@ pub unsafe extern "C" fn Java_net_ankiweb_rsdroid_NativeMethods_getColumnNames( } } +#[no_mangle] +pub unsafe extern "C" fn Java_net_ankiweb_rsdroid_NativeMethods_setDbPageSize( + _: JNIEnv, + _: JClass, + page_size: jlong) { + dbcommand::set_max_page_size(page_size as usize); +} + unsafe fn to_backend(ptr: jlong) -> &'static mut AnkiDroidBackend { // TODO: This is not unwindable, but we can't hard-crash as Android won't send it to ACRA // As long as the FatalError is sent below, we're OK @@ -513,23 +513,6 @@ fn panic_to_bytes(env: JNIEnv , s: &(dyn Any + Send), i18n: &I18n) -> jbyteArray env.byte_array_from_slice(bytes.as_slice()).unwrap() } -/** -Store the data in the cache if there's more than DB_COMMAND_NUM_ROWS.
-Returns: The data capped to DB_COMMAND_NUM_ROWS -*/ -fn trim_and_cache_remaining(backend_ptr: i64, values: DbResult, sequence_number: i32) -> DbResponse { - let row_count = values.rows.len() as i32; - if values.rows.len() > DB_COMMAND_NUM_ROWS { - let result = values.rows.iter().take(DB_COMMAND_NUM_ROWS).cloned().collect(); - let to_store = DbResponse { result: Some(values), sequence_number, row_count }; - dbcommand::insert_cache(backend_ptr, to_store); - - DbResponse { result: Some(DbResult { rows: result }), sequence_number, row_count } - } else { - DbResponse { result: Some(values), sequence_number, row_count } - } -} - fn panic_to_anki_error(s: &(dyn Any + Send)) -> AnkiError { if let Some(msg) = s.downcast_ref::(){ AnkiError::FatalError { @@ -540,5 +523,4 @@ fn panic_to_anki_error(s: &(dyn Any + Send)) -> AnkiError { info: "panic with no info".to_string() } } -} - +} \ No newline at end of file