diff --git a/include/silo/query_engine/actions/action.h b/include/silo/query_engine/actions/action.h index 21d0dd23b..7baf16f18 100644 --- a/include/silo/query_engine/actions/action.h +++ b/include/silo/query_engine/actions/action.h @@ -28,6 +28,7 @@ class Action { std::vector order_by_fields; std::optional limit; std::optional offset; + std::optional randomize_seed; void applySort(QueryResult& result) const; void applyOffsetAndLimit(QueryResult& result) const; @@ -46,7 +47,8 @@ class Action { void setOrdering( const std::vector& order_by_fields, std::optional limit, - std::optional offset + std::optional offset, + std::optional randomize_seed ); [[nodiscard]] virtual QueryResult executeAndOrder( @@ -55,6 +57,12 @@ class Action { ) const; }; +std::optional parseLimit(const nlohmann::json& json); + +std::optional parseOffset(const nlohmann::json& json); + +std::optional parseRandomizeSeed(const nlohmann::json& json); + // NOLINTNEXTLINE(readability-identifier-naming) void from_json(const nlohmann::json& json, std::unique_ptr& action); diff --git a/include/silo/query_engine/actions/tuple.h b/include/silo/query_engine/actions/tuple.h index bc7d815b0..a781e5d06 100644 --- a/include/silo/query_engine/actions/tuple.h +++ b/include/silo/query_engine/actions/tuple.h @@ -52,7 +52,8 @@ class Tuple { static Comparator getComparator( const std::vector& columns_metadata, - const std::vector& order_by_fields + const std::vector& order_by_fields, + const std::optional& randomize_seed ); bool operator==(const Tuple& other) const; diff --git a/src/silo/query_engine/actions/action.cpp b/src/silo/query_engine/actions/action.cpp index 33abf23f1..e9b7c5439 100644 --- a/src/silo/query_engine/actions/action.cpp +++ b/src/silo/query_engine/actions/action.cpp @@ -2,8 +2,10 @@ #include #include +#include #include #include +#include #include #include @@ -51,6 +53,12 @@ void Action::applySort(QueryResult& result) const { static_cast(limit.value_or(result_vector.size()) + offset.value_or(0UL)), result_vector.size() ); + if (randomize_seed) { + std::default_random_engine rng(*randomize_seed); + std::shuffle( + result_vector.begin(), result_vector.begin() + static_cast(end_of_sort), rng + ); + } if (!order_by_fields.empty()) { if (end_of_sort < result_vector.size()) { std::partial_sort( @@ -94,11 +102,13 @@ void Action::applyOffsetAndLimit(QueryResult& result) const { void Action::setOrdering( const std::vector& order_by_fields_, std::optional limit_, - std::optional offset_ + std::optional offset_, + std::optional randomize_seed_ ) { order_by_fields = order_by_fields_; limit = limit_; offset = offset_; + randomize_seed = randomize_seed_; } QueryResult Action::executeAndOrder( @@ -140,6 +150,45 @@ void from_json(const nlohmann::json& json, OrderByField& field) { field = {field_name, order_string == "ascending"}; } +std::optional parseLimit(const nlohmann::json& json) { + CHECK_SILO_QUERY( + !json.contains("limit") || json["limit"].is_number_unsigned(), + "If the action contains a limit, it must be a non-negative number" + ) + return json.contains("limit") ? std::optional(json["limit"].get()) + : std::nullopt; +} + +std::optional parseOffset(const nlohmann::json& json) { + CHECK_SILO_QUERY( + !json.contains("offset") || json["offset"].is_number_unsigned(), + "If the action contains an offset, it must be a non-negative number" + ) + return json.contains("offset") ? std::optional(json["offset"].get()) + : std::nullopt; +} + +std::optional parseRandomizeSeed(const nlohmann::json& json) { + if (json.contains("randomize")) { + if (json["randomize"].is_boolean()) { + if (json["randomize"].get()) { + const uint32_t time_based_seed = + std::chrono::system_clock::now().time_since_epoch().count(); + return time_based_seed; + } + return std::nullopt; + } + CHECK_SILO_QUERY( + json["randomize"].is_object() && json["randomize"].contains("seed") && + json["randomize"]["seed"].is_number_unsigned(), + "If the action contains 'randomize', it must be either a boolean or an object " + "containing an unsigned 'seed'" + ) + return json["randomize"]["seed"].get(); + } + return std::nullopt; +} + // NOLINTNEXTLINE(readability-identifier-naming) void from_json(const nlohmann::json& json, std::unique_ptr& action) { CHECK_SILO_QUERY(json.contains("type"), "The field 'type' is required in any action") @@ -171,19 +220,14 @@ void from_json(const nlohmann::json& json, std::unique_ptr& action) { auto order_by_fields = json.contains("orderByFields") ? json["orderByFields"].get>() : std::vector(); - CHECK_SILO_QUERY( - !json.contains("limit") || json["limit"].is_number_unsigned(), - "If the action contains a limit, it must be a non-negative number" - ) CHECK_SILO_QUERY( !json.contains("offset") || json["offset"].is_number_unsigned(), "If the action contains an offset, it must be a non-negative number" ) - auto limit = json.contains("limit") ? std::optional(json["limit"].get()) - : std::nullopt; - auto offset = json.contains("offset") ? std::optional(json["offset"].get()) - : std::nullopt; - action->setOrdering(order_by_fields, limit, offset); + auto limit = parseLimit(json); + auto offset = parseOffset(json); + auto randomize_seed = parseRandomizeSeed(json); + action->setOrdering(order_by_fields, limit, offset, randomize_seed); } } // namespace silo::query_engine::actions diff --git a/src/silo/query_engine/actions/details.cpp b/src/silo/query_engine/actions/details.cpp index 703e9fe06..b2c7f4159 100644 --- a/src/silo/query_engine/actions/details.cpp +++ b/src/silo/query_engine/actions/details.cpp @@ -1,6 +1,7 @@ #include "silo/query_engine/actions/details.h" #include +#include #include #include @@ -197,14 +198,16 @@ QueryResult Details::executeAndOrder( tuples = produceSortedTuplesWithLimit( tuple_factories, bitmap_filter, - Tuple::getComparator(field_metadata, order_by_fields), + Tuple::getComparator(field_metadata, order_by_fields, randomize_seed), limit.value() + offset.value_or(0) ); } else { tuples = produceAllTuples(tuple_factories, bitmap_filter); - if (!order_by_fields.empty()) { + if (!order_by_fields.empty() || randomize_seed) { std::sort( - tuples.begin(), tuples.end(), Tuple::getComparator(field_metadata, order_by_fields) + tuples.begin(), + tuples.end(), + Tuple::getComparator(field_metadata, order_by_fields, randomize_seed) ); } } diff --git a/src/silo/query_engine/actions/tuple.cpp b/src/silo/query_engine/actions/tuple.cpp index 88386d083..24b0b7d4f 100644 --- a/src/silo/query_engine/actions/tuple.cpp +++ b/src/silo/query_engine/actions/tuple.cpp @@ -8,6 +8,8 @@ #include #include +#include + #include "silo/common/date.h" #include "silo/common/string.h" #include "silo/common/types.h" @@ -363,10 +365,31 @@ std::vector Tuple::getCompareFields( Tuple::Comparator Tuple::getComparator( const std::vector& columns_metadata, - const std::vector& order_by_fields + const std::vector& order_by_fields, + const std::optional& randomize_seed ) { auto tuple_field_comparators = actions::Tuple::getCompareFields(columns_metadata, order_by_fields); + if (randomize_seed) { + const size_t seed = *randomize_seed; + return [tuple_field_comparators, seed](const Tuple& tuple1, const Tuple& tuple2) { + if (tuple1.compareLess(tuple2, tuple_field_comparators)) { + return true; + } + if (tuple2.compareLess(tuple1, tuple_field_comparators)) { + return false; + } + size_t random_number1 = seed; + size_t random_number2 = seed; + boost::hash_combine( + random_number1, std::hash()(tuple1) + ); + boost::hash_combine( + random_number2, std::hash()(tuple2) + ); + return random_number1 < random_number2; + }; + } return [tuple_field_comparators](const Tuple& tuple1, const Tuple& tuple2) { return tuple1.compareLess(tuple2, tuple_field_comparators); }; diff --git a/src/silo/query_engine/actions/tuple.test.cpp b/src/silo/query_engine/actions/tuple.test.cpp index 057180d56..716ff6f5d 100644 --- a/src/silo/query_engine/actions/tuple.test.cpp +++ b/src/silo/query_engine/actions/tuple.test.cpp @@ -211,7 +211,7 @@ TEST(Tuple, comparesFieldsCorrectly) { std::vector order_by_fields; order_by_fields.push_back({"dummy_indexed_string_column", true}); const Tuple::Comparator under_test = - Tuple::getComparator(columns.second.metadata, order_by_fields); + Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt); ASSERT_FALSE(under_test(tuple0a, tuple0b)); ASSERT_FALSE(under_test(tuple0b, tuple0a)); @@ -233,7 +233,7 @@ TEST(Tuple, comparesFieldsCorrectly) { order_by_fields.clear(); const Tuple::Comparator under_test2 = - Tuple::getComparator(columns.second.metadata, order_by_fields); + Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt); ASSERT_FALSE(under_test2(tuple0a, tuple0b)); ASSERT_FALSE(under_test2(tuple0b, tuple0a)); @@ -255,7 +255,7 @@ TEST(Tuple, comparesFieldsCorrectly) { order_by_fields.push_back({"dummy_date_column", true}); const Tuple::Comparator under_test3 = - Tuple::getComparator(columns.second.metadata, order_by_fields); + Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt); ASSERT_FALSE(under_test3(tuple0a, tuple0b)); ASSERT_FALSE(under_test3(tuple0b, tuple0a)); @@ -277,7 +277,7 @@ TEST(Tuple, comparesFieldsCorrectly) { order_by_fields.push_back({"dummy_string_column", false}); const Tuple::Comparator under_test4 = - Tuple::getComparator(columns.second.metadata, order_by_fields); + Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt); ASSERT_FALSE(under_test4(tuple0a, tuple0b)); ASSERT_FALSE(under_test4(tuple0b, tuple0a)); diff --git a/src/silo/test/randomize.test.cpp b/src/silo/test/randomize.test.cpp new file mode 100644 index 000000000..8aefd584d --- /dev/null +++ b/src/silo/test/randomize.test.cpp @@ -0,0 +1,85 @@ +#include + +#include + +#include "silo/test/query_fixture.test.h" + +using silo::ReferenceGenomes; +using silo::config::DatabaseConfig; +using silo::config::ValueType; +using silo::test::QueryTestData; +using silo::test::QueryTestScenario; + +const std::vector DATA = { + {{"metadata", {{"key", "id1"}, {"col", "A"}}}, + {"alignedNucleotideSequences", {{"segment1", nullptr}}}, + {"unalignedNucleotideSequences", {{"segment1", nullptr}}}, + {"alignedAminoAcidSequences", {{"gene1", nullptr}}}}, + {{"metadata", {{"key", "id2"}, {"col", "A"}}}, + {"alignedNucleotideSequences", {{"segment1", nullptr}}}, + {"unalignedNucleotideSequences", {{"segment1", nullptr}}}, + {"alignedAminoAcidSequences", {{"gene1", nullptr}}}}, + {{"metadata", {{"key", "id3"}, {"col", "A"}}}, + {"alignedNucleotideSequences", {{"segment1", nullptr}}}, + {"unalignedNucleotideSequences", {{"segment1", nullptr}}}, + {"alignedAminoAcidSequences", {{"gene1", nullptr}}}}, + {{"metadata", {{"key", "id4"}, {"col", "A"}}}, + {"alignedNucleotideSequences", {{"segment1", nullptr}}}, + {"unalignedNucleotideSequences", {{"segment1", nullptr}}}, + {"alignedAminoAcidSequences", {{"gene1", nullptr}}}}, + {{"metadata", {{"key", "id5"}, {"col", "A"}}}, + {"alignedNucleotideSequences", {{"segment1", nullptr}}}, + {"unalignedNucleotideSequences", {{"segment1", nullptr}}}, + {"alignedAminoAcidSequences", {{"gene1", nullptr}}}} +}; + +const auto DATABASE_CONFIG = DatabaseConfig{ + "segment1", + {"dummy name", {{"key", ValueType::STRING}, {"col", ValueType::STRING}}, "key"} +}; + +const auto REFERENCE_GENOMES = ReferenceGenomes{ + {{"segment1", "A"}}, + {{"gene1", "*"}}, +}; + +const QueryTestData TEST_DATA{DATA, DATABASE_CONFIG, REFERENCE_GENOMES}; + +const QueryTestScenario RANDOMIZE_SEED = { + "seed1231ProvidedShouldShuffleResults", + {{"action", {{"type", "Details"}, {"fields", {"key"}}, {"randomize", {{"seed", 1231}}}}}, + {"filterExpression", {{"type", "True"}}}}, + {{{"key", "id1"}}, {{"key", "id4"}}, {{"key", "id3"}}, {{"key", "id5"}}, {{"key", "id2"}}} +}; + +const QueryTestScenario RANDOMIZE_SEED_DIFFERENT = { + "seed12312ProvidedShouldShuffleResultsDifferently", + {{"action", {{"type", "Details"}, {"fields", {"key"}}, {"randomize", {{"seed", 12312}}}}}, + {"filterExpression", {{"type", "True"}}}}, + {{{"key", "id2"}}, {{"key", "id1"}}, {{"key", "id4"}}, {{"key", "id5"}}, {{"key", "id3"}}} +}; + +const QueryTestScenario EXPLICIT_DO_NOT_RANDOMIZE = { + "explicitlyDoNotRandomize", + {{"action", {{"type", "Details"}, {"fields", {"key"}}, {"randomize", false}}}, + {"filterExpression", {{"type", "True"}}}}, + {{{"key", "id1"}}, {{"key", "id2"}}, {{"key", "id3"}}, {{"key", "id4"}}, {{"key", "id5"}}} +}; + +const QueryTestScenario AGGREGATE = { + "aggregateRandomize", + {{"action", + {{"type", "Aggregated"}, {"groupByFields", {"key"}}, {"randomize", {{"seed", 12321}}}}}, + {"filterExpression", {{"type", "True"}}}}, + {{{"count", 1}, {"key", "id1"}}, + {{"count", 1}, {"key", "id4"}}, + {{"count", 1}, {"key", "id2"}}, + {{"count", 1}, {"key", "id3"}}, + {{"count", 1}, {"key", "id5"}}} +}; + +QUERY_TEST( + RandomizeTest, + TEST_DATA, + ::testing::Values(RANDOMIZE_SEED, RANDOMIZE_SEED_DIFFERENT, EXPLICIT_DO_NOT_RANDOMIZE, AGGREGATE) +);