Skip to content

Commit

Permalink
fix: also consider 'missing' symbols in the mutation action.
Browse files Browse the repository at this point in the history
Bugfix where the Position invariant was broken because of 'missing' symbols
  • Loading branch information
Taepper committed Jan 31, 2024
1 parent 6e15204 commit fab72a6
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 121 deletions.
8 changes: 7 additions & 1 deletion include/silo/query_engine/actions/mutations.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ class Mutations : public Action {
static std::unordered_map<std::string, Mutations<SymbolType>::PrefilteredBitmaps>
preFilterBitmaps(const silo::Database& database, std::vector<OperatorResult>& bitmap_filter);

static void addMutationsCountsForPosition(
static void addPositionToMutationCountsForMixedBitmaps(
uint32_t position,
const PrefilteredBitmaps& bitmaps_to_evaluate,
SymbolMap<SymbolType, std::vector<uint32_t>>& count_of_mutations_per_position
);

static void addPositionToMutationCountsForFullBitmaps(
uint32_t position,
const PrefilteredBitmaps& bitmaps_to_evaluate,
SymbolMap<SymbolType, std::vector<uint32_t>>& count_of_mutations_per_position
Expand Down
4 changes: 3 additions & 1 deletion include/silo/storage/database_partition.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ class DatabasePartition {

void validate() const;

void flipBitmaps();
void optimizeBitmaps();
void optimizeNucleotideBitmaps();
void optimizeAminoAcidBitmaps();

[[nodiscard]] const std::vector<preprocessing::PartitionChunk>& getChunks() const;

Expand Down
2 changes: 1 addition & 1 deletion src/silo/preprocessing/preprocessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ Database Preprocessor::buildDatabase(
);
}
}
database.partitions.at(partition_index).flipBitmaps();
database.partitions.at(partition_index).optimizeBitmaps();
SPDLOG_INFO("build - finished sequences for partition {}", partition_index);
}
}
Expand Down
102 changes: 60 additions & 42 deletions src/silo/query_engine/actions/mutations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,63 +63,73 @@ std::unordered_map<std::string, typename Mutations<SymbolType>::PrefilteredBitma
}

template <typename SymbolType>
void Mutations<SymbolType>::addMutationsCountsForPosition(
void Mutations<SymbolType>::addPositionToMutationCountsForMixedBitmaps(
uint32_t position,
const PrefilteredBitmaps& bitmaps_to_evaluate,
SymbolMap<SymbolType, std::vector<uint32_t>>& count_of_mutations_per_position
) {
for (const auto& [filter, sequence_store_partition] : bitmaps_to_evaluate.bitmaps) {
for (const auto symbol : SymbolType::SYMBOLS) {
if (sequence_store_partition.positions[position].isSymbolDeleted(symbol)) {
const auto& current_position = sequence_store_partition.positions[position];
if (current_position.isSymbolDeleted(symbol)) {
count_of_mutations_per_position[symbol][position] += filter->cardinality();
for (const uint32_t idx : *filter) {
const roaring::Roaring& n_bitmap =
sequence_store_partition.missing_symbol_bitmaps[idx];
if (n_bitmap.contains(position)) {
count_of_mutations_per_position[symbol][position] -= 1;
}
}
continue;
}
if (sequence_store_partition.positions[position].isSymbolFlipped(symbol)) {
count_of_mutations_per_position[symbol][position] += filter->andnot_cardinality(
*sequence_store_partition.positions[position].getBitmap(symbol)
);
} else {
count_of_mutations_per_position[symbol][position] += filter->and_cardinality(
*sequence_store_partition.positions[position].getBitmap(symbol)
);
const uint32_t symbol_count =
current_position.isSymbolFlipped(symbol)
? filter->andnot_cardinality(*current_position.getBitmap(symbol))
: filter->and_cardinality(*current_position.getBitmap(symbol));

count_of_mutations_per_position[symbol][position] += symbol_count;

const auto deleted_symbol = current_position.getDeletedSymbol();
if (deleted_symbol.has_value() && symbol != *deleted_symbol) {
count_of_mutations_per_position[*deleted_symbol][position] -= symbol_count;
}
}
}
}

template <typename SymbolType>
void Mutations<SymbolType>::addPositionToMutationCountsForFullBitmaps(
uint32_t position,
const PrefilteredBitmaps& bitmaps_to_evaluate,
SymbolMap<SymbolType, std::vector<uint32_t>>& count_of_mutations_per_position
) {
// For these partitions, we have full bitmaps. Do not need to bother with AND
// cardinality
for (const auto& [filter, sequence_store_partition] : bitmaps_to_evaluate.full_bitmaps) {
for (const auto symbol : SymbolType::SYMBOLS) {
if (sequence_store_partition.positions[position].isSymbolDeleted(symbol)) {
continue;
}
if (sequence_store_partition.positions[position].isSymbolFlipped(symbol)) {
count_of_mutations_per_position[symbol][position] +=
sequence_store_partition.sequence_count -
sequence_store_partition.positions[position].getBitmap(symbol)->cardinality();
} else {
const auto& current_position = sequence_store_partition.positions[position];
if (current_position.isSymbolDeleted(symbol)) {
count_of_mutations_per_position[symbol][position] +=
sequence_store_partition.positions[position].getBitmap(symbol)->cardinality();
sequence_store_partition.sequence_count;
for (const roaring::Roaring& n_bitmap :
sequence_store_partition.missing_symbol_bitmaps) {
if (n_bitmap.contains(position)) {
count_of_mutations_per_position[symbol][position] -= 1;
}
}
continue;
}
}
}
const uint32_t symbol_count = current_position.isSymbolFlipped(symbol)
? sequence_store_partition.sequence_count -
current_position.getBitmap(symbol)->cardinality()
: current_position.getBitmap(symbol)->cardinality();

for (const auto& [filter, sequence_store_partition] : bitmaps_to_evaluate.bitmaps) {
const auto deleted_symbol = sequence_store_partition.positions[position].getDeletedSymbol();
if (deleted_symbol) {
count_of_mutations_per_position[*deleted_symbol][position] += filter->cardinality();
for (const auto symbol : SymbolType::SYMBOLS) {
count_of_mutations_per_position[*deleted_symbol][position] -=
count_of_mutations_per_position[symbol][position];
}
}
}
for (const auto& [filter, sequence_store_partition] : bitmaps_to_evaluate.full_bitmaps) {
const auto deleted_symbol = sequence_store_partition.positions[position].getDeletedSymbol();
if (deleted_symbol) {
count_of_mutations_per_position[*deleted_symbol][position] +=
sequence_store_partition.sequence_count;
for (const auto symbol : SymbolType::SYMBOLS) {
count_of_mutations_per_position[symbol][position] += symbol_count;

const auto deleted_symbol = current_position.getDeletedSymbol();
if (deleted_symbol.has_value() && symbol != *deleted_symbol) {
count_of_mutations_per_position[*deleted_symbol][position] -=
count_of_mutations_per_position[symbol][position];
sequence_store_partition.positions[position].getBitmap(symbol)->cardinality();
}
}
}
Expand All @@ -132,20 +142,28 @@ SymbolMap<SymbolType, std::vector<uint32_t>> Mutations<SymbolType>::calculateMut
) {
const size_t sequence_length = sequence_store.reference_sequence.size();

SymbolMap<SymbolType, std::vector<uint32_t>> count_of_mutations_per_position;
SymbolMap<SymbolType, std::vector<uint32_t>> mutation_counts_per_position;
for (const auto symbol : SymbolType::SYMBOLS) {
count_of_mutations_per_position[symbol].resize(sequence_length);
mutation_counts_per_position[symbol].resize(sequence_length);
}
static constexpr int POSITIONS_PER_PROCESS = 300;
tbb::parallel_for(
tbb::blocked_range<uint32_t>(0, sequence_length, /*grain_size=*/POSITIONS_PER_PROCESS),
[&](const auto& local) {
for (uint32_t pos = local.begin(); pos != local.end(); ++pos) {
addMutationsCountsForPosition(pos, bitmap_filter, count_of_mutations_per_position);
addPositionToMutationCountsForMixedBitmaps(
pos, bitmap_filter, mutation_counts_per_position
);
addPositionToMutationCountsForFullBitmaps(
pos, bitmap_filter, mutation_counts_per_position
);
correctDeletedSymbolCountForMissingSymbols(
pos, bitmap_filter, mutation_counts_per_position
);
}
}
);
return count_of_mutations_per_position;
return mutation_counts_per_position;
}

template <typename SymbolType>
Expand Down
9 changes: 6 additions & 3 deletions src/silo/query_engine/filter_expressions/aa_symbol_equals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
#include <nlohmann/json.hpp>

#include "silo/common/aa_symbols.h"
#include "silo/query_engine/filter_expressions/and.h"
#include "silo/query_engine/filter_expressions/expression.h"
#include "silo/query_engine/filter_expressions/or.h"
#include "silo/query_engine/filter_expressions/negation.h"
#include "silo/query_engine/operators/bitmap_selection.h"
#include "silo/query_engine/operators/complement.h"
#include "silo/query_engine/operators/index_scan.h"
Expand Down Expand Up @@ -78,10 +79,12 @@ std::unique_ptr<silo::query_engine::operators::Operator> AASymbolEquals::compile
symbols.end(),
std::back_inserter(symbol_filters),
[&](AminoAcid::Symbol symbol) {
return std::make_unique<AASymbolEquals>(aa_sequence_name, position, symbol);
return std::make_unique<Negation>(
std::make_unique<AASymbolEquals>(aa_sequence_name, position, symbol)
);
}
);
return Or(std::move(symbol_filters)).compile(database, database_partition, NONE);
return And(std::move(symbol_filters)).compile(database, database_partition, NONE);
}
return std::make_unique<operators::IndexScan>(
aa_store_partition.getBitmap(position, aa_symbol), database_partition.sequence_count
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include "silo/common/nucleotide_symbols.h"
#include "silo/config/database_config.h"
#include "silo/database.h"
#include "silo/query_engine/filter_expressions/and.h"
#include "silo/query_engine/filter_expressions/expression.h"
#include "silo/query_engine/filter_expressions/negation.h"
#include "silo/query_engine/filter_expressions/or.h"
#include "silo/query_engine/operators/bitmap_selection.h"
#include "silo/query_engine/operators/complement.h"
Expand Down Expand Up @@ -124,8 +126,7 @@ std::unique_ptr<silo::query_engine::operators::Operator> NucleotideSymbolEquals:
);
}
);
return std::make_unique<Or>(std::move(symbol_filters))
->compile(database, database_partition, NONE);
return Or(std::move(symbol_filters)).compile(database, database_partition, NONE);
}
if (nucleotide_symbol == Nucleotide::SYMBOL_MISSING) {
SPDLOG_TRACE(
Expand Down Expand Up @@ -170,12 +171,12 @@ std::unique_ptr<silo::query_engine::operators::Operator> NucleotideSymbolEquals:
symbols.end(),
std::back_inserter(symbol_filters),
[&](Nucleotide::Symbol symbol) {
return std::make_unique<NucleotideSymbolEquals>(
return std::make_unique<Negation>(std::make_unique<NucleotideSymbolEquals>(
nuc_sequence_name_or_default, position, symbol
);
));
}
);
return Or(std::move(symbol_filters)).compile(database, database_partition, NONE);
return And(std::move(symbol_filters)).compile(database, database_partition, NONE);
}
SPDLOG_TRACE(
"Filtering for symbol '{}' at position {}",
Expand Down
23 changes: 15 additions & 8 deletions src/silo/storage/database_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ void DatabasePartition::validateMetadataColumns() const {
}
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void DatabasePartition::flipBitmaps() {
void DatabasePartition::optimizeBitmaps() {
optimizeNucleotideBitmaps();
optimizeAminoAcidBitmaps();
}

void DatabasePartition::optimizeNucleotideBitmaps() {
for (auto& [_, seq_store] : nuc_sequences) {
tbb::enumerable_thread_specific<decltype(seq_store.indexing_differences_to_reference_sequence
)>
Expand All @@ -177,21 +181,24 @@ void DatabasePartition::flipBitmaps() {
}
}
}
}

void DatabasePartition::optimizeAminoAcidBitmaps() {
for (auto& [_, aa_store] : aa_sequences) {
tbb::enumerable_thread_specific<decltype(aa_store.indexing_differences_to_reference_sequence)>
flipped_bitmaps;
index_changes_to_reference;

auto& positions = aa_store.positions;
tbb::parallel_for(tbb::blocked_range<uint32_t>(0, positions.size()), [&](const auto& local) {
auto& local_flipped_bitmaps = flipped_bitmaps.local();
auto& local_index_changes = index_changes_to_reference.local();
for (auto position = local.begin(); position != local.end(); ++position) {
auto flipped_symbol = positions[position].deleteMostNumerousBitmap(sequence_count);
if (flipped_symbol.has_value()) {
local_flipped_bitmaps.emplace_back(position, *flipped_symbol);
auto symbol_changed = positions[position].deleteMostNumerousBitmap(sequence_count);
if (symbol_changed.has_value()) {
local_index_changes.emplace_back(position, *symbol_changed);
}
}
});
for (const auto& local : flipped_bitmaps) {
for (const auto& local : index_changes_to_reference) {
for (const auto& element : local) {
aa_store.indexing_differences_to_reference_sequence.emplace_back(element);
}
Expand Down
52 changes: 20 additions & 32 deletions src/silo/storage/position.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,28 @@ template <typename SymbolType>
std::optional<typename SymbolType::Symbol> silo::Position<SymbolType>::getHighestCardinalitySymbol(
uint32_t sequence_count
) {
if (symbol_whose_bitmap_is_deleted.has_value()) {
throw std::runtime_error(fmt::format(
"Symbol '{}' is currently deleted. Cannot restore it implicitly and cannot calculate its "
"cardinality as we do not have information about missing symbols",
SymbolType::symbolToChar(*symbol_whose_bitmap_is_deleted)
));
}

std::optional<typename SymbolType::Symbol> max_symbol = std::nullopt;
uint32_t max_count = 0;

uint32_t count_sum = 0;

for (const auto& symbol : SymbolType::SYMBOLS) {
roaring::Roaring& bitmap = bitmaps[symbol];
bitmap.runOptimize();
bitmap.shrinkToFit();
const uint32_t count =
isSymbolFlipped(symbol) ? sequence_count - bitmap.cardinality() : bitmap.cardinality();
count_sum += count;
if (count > max_count) {
max_symbol = symbol;
max_count = count;
}
}
if (symbol_whose_bitmap_is_deleted.has_value()) {
if (sequence_count - count_sum > max_count) {
return symbol_whose_bitmap_is_deleted;
}
}
return max_symbol;
}

Expand All @@ -72,16 +72,10 @@ std::optional<typename SymbolType::Symbol> silo::Position<SymbolType>::flipMostN
uint32_t sequence_count
) {
if (symbol_whose_bitmap_is_deleted.has_value()) {
const auto missing_symbol = symbol_whose_bitmap_is_deleted.value();
for (const auto& symbol : SymbolType::SYMBOLS) {
if (symbol != missing_symbol) {
bitmaps[missing_symbol] |= bitmaps.at(symbol);
}
}
bitmaps[missing_symbol].flip(0, sequence_count);
bitmaps[missing_symbol].runOptimize();
bitmaps[missing_symbol].shrinkToFit();
symbol_whose_bitmap_is_deleted = std::nullopt;
throw std::runtime_error(fmt::format(
"Symbol '{}' is currently deleted. Cannot restore it implicitly",
SymbolType::symbolToChar(*symbol_whose_bitmap_is_deleted)
));
}

std::optional<typename SymbolType::Symbol> max_symbol =
Expand All @@ -108,6 +102,12 @@ template <typename SymbolType>
std::optional<typename SymbolType::Symbol> silo::Position<SymbolType>::deleteMostNumerousBitmap(
uint32_t sequence_count
) {
if (symbol_whose_bitmap_is_deleted.has_value()) {
throw std::runtime_error(fmt::format(
"Symbol '{}' is currently deleted. Cannot restore it implicitly",
SymbolType::symbolToChar(*symbol_whose_bitmap_is_deleted)
));
}
if (symbol_whose_bitmap_is_flipped.has_value()) {
bitmaps[*symbol_whose_bitmap_is_flipped].flip(0, sequence_count);
bitmaps[*symbol_whose_bitmap_is_flipped].runOptimize();
Expand All @@ -118,20 +118,8 @@ std::optional<typename SymbolType::Symbol> silo::Position<SymbolType>::deleteMos
std::optional<typename SymbolType::Symbol> max_symbol =
getHighestCardinalitySymbol(sequence_count);

if (max_symbol != symbol_whose_bitmap_is_deleted) {
if (symbol_whose_bitmap_is_deleted.has_value()) {
for (const auto& symbol : SymbolType::SYMBOLS) {
if (symbol != *symbol_whose_bitmap_is_deleted) {
bitmaps[*symbol_whose_bitmap_is_deleted] |= bitmaps.at(symbol);
}
}
bitmaps[*symbol_whose_bitmap_is_deleted].flip(0, sequence_count);
bitmaps[*symbol_whose_bitmap_is_deleted].runOptimize();
bitmaps[*symbol_whose_bitmap_is_deleted].shrinkToFit();
}
if (max_symbol.has_value()) {
bitmaps[*max_symbol] = roaring::Roaring();
}
if (max_symbol.has_value()) {
bitmaps[*max_symbol] = roaring::Roaring();
symbol_whose_bitmap_is_deleted = max_symbol;
return symbol_whose_bitmap_is_deleted;
}
Expand Down
Loading

0 comments on commit fab72a6

Please sign in to comment.