Skip to content

Commit

Permalink
feat: reintroduce randomize for all query actions
Browse files Browse the repository at this point in the history
  • Loading branch information
Taepper committed Feb 27, 2024
1 parent bddc6e9 commit 166045c
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 20 deletions.
10 changes: 9 additions & 1 deletion include/silo/query_engine/actions/action.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Action {
std::vector<OrderByField> order_by_fields;
std::optional<uint32_t> limit;
std::optional<uint32_t> offset;
std::optional<uint32_t> randomize_seed;

void applySort(QueryResult& result) const;
void applyOffsetAndLimit(QueryResult& result) const;
Expand All @@ -46,7 +47,8 @@ class Action {
void setOrdering(
const std::vector<OrderByField>& order_by_fields,
std::optional<uint32_t> limit,
std::optional<uint32_t> offset
std::optional<uint32_t> offset,
std::optional<uint32_t> randomize_seed
);

[[nodiscard]] virtual QueryResult executeAndOrder(
Expand All @@ -55,6 +57,12 @@ class Action {
) const;
};

std::optional<uint32_t> parseLimit(const nlohmann::json& json);

std::optional<uint32_t> parseOffset(const nlohmann::json& json);

std::optional<uint32_t> parseRandomizeSeed(const nlohmann::json& json);

// NOLINTNEXTLINE(readability-identifier-naming)
void from_json(const nlohmann::json& json, std::unique_ptr<Action>& action);

Expand Down
3 changes: 2 additions & 1 deletion include/silo/query_engine/actions/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class Tuple {

static Comparator getComparator(
const std::vector<silo::storage::ColumnMetadata>& columns_metadata,
const std::vector<OrderByField>& order_by_fields
const std::vector<OrderByField>& order_by_fields,
const std::optional<uint32_t>& randomize_seed
);

bool operator==(const Tuple& other) const;
Expand Down
64 changes: 54 additions & 10 deletions src/silo/query_engine/actions/action.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

#include <algorithm>
#include <cctype>
#include <chrono>
#include <map>
#include <memory>
#include <random>
#include <utility>

#include <nlohmann/json.hpp>
Expand Down Expand Up @@ -51,6 +53,12 @@ void Action::applySort(QueryResult& result) const {
static_cast<size_t>(limit.value_or(result_vector.size()) + offset.value_or(0UL)),
result_vector.size()
);
if (randomize_seed) {
std::default_random_engine rng(*randomize_seed);
std::shuffle(
result_vector.begin(), result_vector.begin() + static_cast<int64_t>(end_of_sort), rng
);
}
if (!order_by_fields.empty()) {
if (end_of_sort < result_vector.size()) {
std::partial_sort(
Expand Down Expand Up @@ -94,11 +102,13 @@ void Action::applyOffsetAndLimit(QueryResult& result) const {
void Action::setOrdering(
const std::vector<OrderByField>& order_by_fields_,
std::optional<uint32_t> limit_,
std::optional<uint32_t> offset_
std::optional<uint32_t> offset_,
std::optional<uint32_t> randomize_seed_
) {
order_by_fields = order_by_fields_;
limit = limit_;
offset = offset_;
randomize_seed = randomize_seed_;
}

QueryResult Action::executeAndOrder(
Expand Down Expand Up @@ -140,6 +150,45 @@ void from_json(const nlohmann::json& json, OrderByField& field) {
field = {field_name, order_string == "ascending"};
}

std::optional<uint32_t> parseLimit(const nlohmann::json& json) {
CHECK_SILO_QUERY(
!json.contains("limit") || json["limit"].is_number_unsigned(),
"If the action contains a limit, it must be a non-negative number"
)
return json.contains("limit") ? std::optional<uint32_t>(json["limit"].get<uint32_t>())
: std::nullopt;
}

std::optional<uint32_t> parseOffset(const nlohmann::json& json) {
CHECK_SILO_QUERY(
!json.contains("offset") || json["offset"].is_number_unsigned(),
"If the action contains an offset, it must be a non-negative number"
)
return json.contains("offset") ? std::optional<uint32_t>(json["offset"].get<uint32_t>())
: std::nullopt;
}

std::optional<uint32_t> parseRandomizeSeed(const nlohmann::json& json) {
if (json.contains("randomize")) {
if (json["randomize"].is_boolean()) {
if (json["randomize"].get<bool>()) {
const uint32_t time_based_seed =
std::chrono::system_clock::now().time_since_epoch().count();
return time_based_seed;
}
return std::nullopt;
}
CHECK_SILO_QUERY(
json["randomize"].is_object() && json["randomize"].contains("seed") &&
json["randomize"]["seed"].is_number_unsigned(),
"If the action contains 'randomize', it must be either a boolean or an object "
"containing an unsigned 'seed'"
)
return json["randomize"]["seed"].get<uint32_t>();
}
return std::nullopt;
}

// NOLINTNEXTLINE(readability-identifier-naming)
void from_json(const nlohmann::json& json, std::unique_ptr<Action>& action) {
CHECK_SILO_QUERY(json.contains("type"), "The field 'type' is required in any action")
Expand Down Expand Up @@ -171,19 +220,14 @@ void from_json(const nlohmann::json& json, std::unique_ptr<Action>& action) {
auto order_by_fields = json.contains("orderByFields")
? json["orderByFields"].get<std::vector<OrderByField>>()
: std::vector<OrderByField>();
CHECK_SILO_QUERY(
!json.contains("limit") || json["limit"].is_number_unsigned(),
"If the action contains a limit, it must be a non-negative number"
)
CHECK_SILO_QUERY(
!json.contains("offset") || json["offset"].is_number_unsigned(),
"If the action contains an offset, it must be a non-negative number"
)
auto limit = json.contains("limit") ? std::optional<uint32_t>(json["limit"].get<uint32_t>())
: std::nullopt;
auto offset = json.contains("offset") ? std::optional<uint32_t>(json["offset"].get<uint32_t>())
: std::nullopt;
action->setOrdering(order_by_fields, limit, offset);
auto limit = parseLimit(json);
auto offset = parseOffset(json);
auto randomize_seed = parseRandomizeSeed(json);
action->setOrdering(order_by_fields, limit, offset, randomize_seed);
}

} // namespace silo::query_engine::actions
9 changes: 6 additions & 3 deletions src/silo/query_engine/actions/details.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "silo/query_engine/actions/details.h"

#include <algorithm>
#include <random>
#include <utility>

#include <oneapi/tbb/blocked_range.h>
Expand Down Expand Up @@ -197,14 +198,16 @@ QueryResult Details::executeAndOrder(
tuples = produceSortedTuplesWithLimit(
tuple_factories,
bitmap_filter,
Tuple::getComparator(field_metadata, order_by_fields),
Tuple::getComparator(field_metadata, order_by_fields, randomize_seed),
limit.value() + offset.value_or(0)
);
} else {
tuples = produceAllTuples(tuple_factories, bitmap_filter);
if (!order_by_fields.empty()) {
if (!order_by_fields.empty() || randomize_seed) {
std::sort(
tuples.begin(), tuples.end(), Tuple::getComparator(field_metadata, order_by_fields)
tuples.begin(),
tuples.end(),
Tuple::getComparator(field_metadata, order_by_fields, randomize_seed)
);
}
}
Expand Down
25 changes: 24 additions & 1 deletion src/silo/query_engine/actions/tuple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <string_view>
#include <utility>

#include <boost/container_hash/hash.hpp>

#include "silo/common/date.h"
#include "silo/common/string.h"
#include "silo/common/types.h"
Expand Down Expand Up @@ -363,10 +365,31 @@ std::vector<Tuple::ComparatorField> Tuple::getCompareFields(

Tuple::Comparator Tuple::getComparator(
const std::vector<silo::storage::ColumnMetadata>& columns_metadata,
const std::vector<OrderByField>& order_by_fields
const std::vector<OrderByField>& order_by_fields,
const std::optional<uint32_t>& randomize_seed
) {
auto tuple_field_comparators =
actions::Tuple::getCompareFields(columns_metadata, order_by_fields);
if (randomize_seed) {
const size_t seed = *randomize_seed;
return [tuple_field_comparators, seed](const Tuple& tuple1, const Tuple& tuple2) {
if (tuple1.compareLess(tuple2, tuple_field_comparators)) {
return true;
}
if (tuple2.compareLess(tuple1, tuple_field_comparators)) {
return false;
}
size_t random_number1 = seed;
size_t random_number2 = seed;
boost::hash_combine(
random_number1, std::hash<silo::query_engine::actions::Tuple>()(tuple1)
);
boost::hash_combine(
random_number2, std::hash<silo::query_engine::actions::Tuple>()(tuple2)
);
return random_number1 < random_number2;
};
}
return [tuple_field_comparators](const Tuple& tuple1, const Tuple& tuple2) {
return tuple1.compareLess(tuple2, tuple_field_comparators);
};
Expand Down
8 changes: 4 additions & 4 deletions src/silo/query_engine/actions/tuple.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ TEST(Tuple, comparesFieldsCorrectly) {
std::vector<silo::query_engine::actions::OrderByField> order_by_fields;
order_by_fields.push_back({"dummy_indexed_string_column", true});
const Tuple::Comparator under_test =
Tuple::getComparator(columns.second.metadata, order_by_fields);
Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt);

ASSERT_FALSE(under_test(tuple0a, tuple0b));
ASSERT_FALSE(under_test(tuple0b, tuple0a));
Expand All @@ -233,7 +233,7 @@ TEST(Tuple, comparesFieldsCorrectly) {

order_by_fields.clear();
const Tuple::Comparator under_test2 =
Tuple::getComparator(columns.second.metadata, order_by_fields);
Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt);

ASSERT_FALSE(under_test2(tuple0a, tuple0b));
ASSERT_FALSE(under_test2(tuple0b, tuple0a));
Expand All @@ -255,7 +255,7 @@ TEST(Tuple, comparesFieldsCorrectly) {

order_by_fields.push_back({"dummy_date_column", true});
const Tuple::Comparator under_test3 =
Tuple::getComparator(columns.second.metadata, order_by_fields);
Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt);

ASSERT_FALSE(under_test3(tuple0a, tuple0b));
ASSERT_FALSE(under_test3(tuple0b, tuple0a));
Expand All @@ -277,7 +277,7 @@ TEST(Tuple, comparesFieldsCorrectly) {

order_by_fields.push_back({"dummy_string_column", false});
const Tuple::Comparator under_test4 =
Tuple::getComparator(columns.second.metadata, order_by_fields);
Tuple::getComparator(columns.second.metadata, order_by_fields, std::nullopt);

ASSERT_FALSE(under_test4(tuple0a, tuple0b));
ASSERT_FALSE(under_test4(tuple0b, tuple0a));
Expand Down
85 changes: 85 additions & 0 deletions src/silo/test/randomize.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include <nlohmann/json.hpp>

#include <optional>

#include "silo/test/query_fixture.test.h"

using silo::ReferenceGenomes;
using silo::config::DatabaseConfig;
using silo::config::ValueType;
using silo::test::QueryTestData;
using silo::test::QueryTestScenario;

const std::vector<nlohmann::json> DATA = {
{{"metadata", {{"key", "id1"}, {"col", "A"}}},
{"alignedNucleotideSequences", {{"segment1", nullptr}}},
{"unalignedNucleotideSequences", {{"segment1", nullptr}}},
{"alignedAminoAcidSequences", {{"gene1", nullptr}}}},
{{"metadata", {{"key", "id2"}, {"col", "A"}}},
{"alignedNucleotideSequences", {{"segment1", nullptr}}},
{"unalignedNucleotideSequences", {{"segment1", nullptr}}},
{"alignedAminoAcidSequences", {{"gene1", nullptr}}}},
{{"metadata", {{"key", "id3"}, {"col", "A"}}},
{"alignedNucleotideSequences", {{"segment1", nullptr}}},
{"unalignedNucleotideSequences", {{"segment1", nullptr}}},
{"alignedAminoAcidSequences", {{"gene1", nullptr}}}},
{{"metadata", {{"key", "id4"}, {"col", "A"}}},
{"alignedNucleotideSequences", {{"segment1", nullptr}}},
{"unalignedNucleotideSequences", {{"segment1", nullptr}}},
{"alignedAminoAcidSequences", {{"gene1", nullptr}}}},
{{"metadata", {{"key", "id5"}, {"col", "A"}}},
{"alignedNucleotideSequences", {{"segment1", nullptr}}},
{"unalignedNucleotideSequences", {{"segment1", nullptr}}},
{"alignedAminoAcidSequences", {{"gene1", nullptr}}}}
};

const auto DATABASE_CONFIG = DatabaseConfig{
"segment1",
{"dummy name", {{"key", ValueType::STRING}, {"col", ValueType::STRING}}, "key"}
};

const auto REFERENCE_GENOMES = ReferenceGenomes{
{{"segment1", "A"}},
{{"gene1", "*"}},
};

const QueryTestData TEST_DATA{DATA, DATABASE_CONFIG, REFERENCE_GENOMES};

const QueryTestScenario RANDOMIZE_SEED = {
"seed1231ProvidedShouldShuffleResults",
{{"action", {{"type", "Details"}, {"fields", {"key"}}, {"randomize", {{"seed", 1231}}}}},
{"filterExpression", {{"type", "True"}}}},
{{{"key", "id1"}}, {{"key", "id4"}}, {{"key", "id3"}}, {{"key", "id5"}}, {{"key", "id2"}}}
};

const QueryTestScenario RANDOMIZE_SEED_DIFFERENT = {
"seed12312ProvidedShouldShuffleResultsDifferently",
{{"action", {{"type", "Details"}, {"fields", {"key"}}, {"randomize", {{"seed", 12312}}}}},
{"filterExpression", {{"type", "True"}}}},
{{{"key", "id2"}}, {{"key", "id1"}}, {{"key", "id4"}}, {{"key", "id5"}}, {{"key", "id3"}}}
};

const QueryTestScenario EXPLICIT_DO_NOT_RANDOMIZE = {
"explicitlyDoNotRandomize",
{{"action", {{"type", "Details"}, {"fields", {"key"}}, {"randomize", false}}},
{"filterExpression", {{"type", "True"}}}},
{{{"key", "id1"}}, {{"key", "id2"}}, {{"key", "id3"}}, {{"key", "id4"}}, {{"key", "id5"}}}
};

const QueryTestScenario AGGREGATE = {
"aggregateRandomize",
{{"action",
{{"type", "Aggregated"}, {"groupByFields", {"key"}}, {"randomize", {{"seed", 12321}}}}},
{"filterExpression", {{"type", "True"}}}},
{{{"count", 1}, {"key", "id1"}},
{{"count", 1}, {"key", "id4"}},
{{"count", 1}, {"key", "id2"}},
{{"count", 1}, {"key", "id3"}},
{{"count", 1}, {"key", "id5"}}}
};

QUERY_TEST(
RandomizeTest,
TEST_DATA,
::testing::Values(RANDOMIZE_SEED, RANDOMIZE_SEED_DIFFERENT, EXPLICIT_DO_NOT_RANDOMIZE, AGGREGATE)
);

0 comments on commit 166045c

Please sign in to comment.