Skip to content

Commit

Permalink
Redesign FST window averaging implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
lczech committed Jun 4, 2024
1 parent cdccf24 commit ddfb3fb
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 204 deletions.
17 changes: 11 additions & 6 deletions lib/genesis/population/function/fst_cathedral.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "genesis/population/function/fst_cathedral.hpp"

#include "genesis/population/function/window_average.hpp"
#include "genesis/utils/formats/json/document.hpp"

#include <cassert>
Expand Down Expand Up @@ -183,23 +184,27 @@ void fill_fst_cathedral_records_from_processor_(
// Bit hacky, but good enough for now. Then, store the results.
assert( processor.size() == records.size() );
for( size_t i = 0; i < processor.size(); ++i ) {
auto& raw_calc = processor.calculators()[i];
auto fst_calc = dynamic_cast<FstPoolCalculatorUnbiased const*>( raw_calc.get() );
auto const& raw_calc = processor.calculators()[i];
auto const* fst_calc = dynamic_cast<FstPoolCalculatorUnbiased const*>( raw_calc.get() );
if( ! fst_calc ) {
throw std::runtime_error(
"In compute_fst_cathedral_records_for_chromosome(): "
"Invalid FstPoolCalculator that is not FstPoolCalculatorUnbiased"
);
}
if( fst_calc->get_window_average_policy() != WindowAveragePolicy::kSum ) {
throw std::runtime_error(
"In compute_fst_cathedral_records_for_chromosome(): "
"Invalid FstPoolCalculator that is not using WindowAveragePolicy::kSum"
);
}

// Now add the entry for the current calculator to its respective records entry.
// We rely on the amortized complexity here - cannot pre-allocate the size,
// as we do not know how many positions will actually be in the input beforehand.
auto const pis = fst_calc->get_pi_values( 0, processor.get_filter_stats() );
records[i].entries.emplace_back(
position,
fst_calc->get_pi_within(),
fst_calc->get_pi_between(),
fst_calc->get_pi_total()
position, pis.pi_within, pis.pi_between, pis.pi_total
);
}
}
Expand Down
16 changes: 13 additions & 3 deletions lib/genesis/population/function/fst_pool_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@
* @ingroup population
*/

#include "genesis/population/filter/filter_stats.hpp"
#include "genesis/population/filter/filter_status.hpp"
#include "genesis/population/filter/sample_counts_filter.hpp"
#include "genesis/population/filter/variant_filter.hpp"
#include "genesis/population/function/fst_pool_karlsson.hpp"
#include "genesis/population/function/fst_pool_kofler.hpp"
#include "genesis/population/function/fst_pool_unbiased.hpp"
#include "genesis/population/function/fst_pool_processor.hpp"
#include "genesis/population/function/fst_pool_unbiased.hpp"
#include "genesis/population/function/functions.hpp"
#include "genesis/population/function/window_average.hpp"
#include "genesis/population/variant.hpp"
#include "genesis/utils/containers/matrix.hpp"
#include "genesis/utils/containers/transform_iterator.hpp"
Expand Down Expand Up @@ -341,7 +346,10 @@ std::pair<double, double> f_st_pool_unbiased(
}

// Init the calculator.
FstPoolCalculatorUnbiased calc{ p1_poolsize, p2_poolsize };
// For simplicity in this wrapper, we only allow to normalize the pi values per window
// via their sum; in this function, we do not have the additionally needed information
// on the Variant Filter Status statistics anyway. If that is needed, use FstPoolProcessor.
FstPoolCalculatorUnbiased calc{ p1_poolsize, p2_poolsize, WindowAveragePolicy::kSum };

// Iterate the two ranges in parallel. Each iteration is one position in the genome.
// In each iteration, p1_it and p2_it point at SampleCounts objects containing nucleotide counts.
Expand All @@ -362,7 +370,9 @@ std::pair<double, double> f_st_pool_unbiased(
);
}

return calc.get_result_pair();
// Unfortunately, we need dummies here now for the window based counters. They are not used
// with the above WindowAveragePolicy::kSum policy, but need to be provided nonetheless...
return calc.get_result_pair( 0, VariantFilterStats{} );
}

#if __cplusplus >= 201402L
Expand Down
102 changes: 44 additions & 58 deletions lib/genesis/population/function/fst_pool_processor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,6 @@ class FstPoolProcessor
// Constructors and Rule of Five
// -------------------------------------------------------------------------

/**
* @brief Default constructor.
*
* We always want to make sure that the user provides a WindowAveragePolicy, so using this
* default constructor leads to an unusable instance. We provide it so that dummy processors
* can be constructed, but they have to be replaced by non-default-constructed instances
* befor usage.
*/
FstPoolProcessor() = default;

/**
* @brief Construct a processor.
*
Expand All @@ -94,14 +84,11 @@ class FstPoolProcessor
* is to be used, deactivate by explicitly setting the thread_pool() function here to `nullptr`.
*/
FstPoolProcessor(
WindowAveragePolicy window_average_policy,
std::shared_ptr<utils::ThreadPool> thread_pool = nullptr,
size_t threading_threshold = 4096
)
: avg_policy_( window_average_policy )
, thread_pool_( thread_pool )
: thread_pool_( thread_pool )
, threading_threshold_( threading_threshold )
, is_default_constructed_( false )
{
thread_pool_ = thread_pool ? thread_pool : utils::Options::get().global_thread_pool();
}
Expand Down Expand Up @@ -189,6 +176,7 @@ class FstPoolProcessor

void reset()
{
assert( calculators_.size() == results_.size() );
for( auto& calc : calculators_ ) {
assert( calc );
calc->reset();
Expand All @@ -201,6 +189,10 @@ class FstPoolProcessor

// Also reset the pi vectors to nan.
// If they are not allocated, nothing happens.
auto const res_sz = results_.size();
assert( std::get<0>( results_pi_ ).size() == 0 || std::get<0>( results_pi_ ).size() == res_sz );
assert( std::get<1>( results_pi_ ).size() == 0 || std::get<1>( results_pi_ ).size() == res_sz );
assert( std::get<2>( results_pi_ ).size() == 0 || std::get<2>( results_pi_ ).size() == res_sz );
std::fill(
std::get<0>( results_pi_ ).begin(), std::get<0>( results_pi_ ).end(),
std::numeric_limits<double>::quiet_NaN()
Expand All @@ -219,9 +211,6 @@ class FstPoolProcessor
{
// Check correct usage
assert( sample_pairs_.size() == calculators_.size() );
if( is_default_constructed_ ) {
throw std::domain_error( "Cannot use a default constructed FstPoolProcessor" );
}

// Only process Variants that are passing, but keep track of the ones that did not.
++filter_stats_[variant.status.get()];
Expand Down Expand Up @@ -272,11 +261,19 @@ class FstPoolProcessor
{
assert( results_.size() == calculators_.size() );
for( size_t i = 0; i < results_.size(); ++i ) {
auto const abs_fst = calculators_[i]->get_result();
auto const window_avg_denom = window_average_denominator(
avg_policy_, window_length, filter_stats_, calculators_[i]->get_filter_stats()
);
results_[i] = abs_fst / window_avg_denom;
// We do an ugly dispatch here to treat the special case of the FstPoolCalculatorUnbiased
// class, which needs additional information on the window in order to normalize
// the pi values correctly. The Kofler and Karlsson do not need that, and we want to
// avoid using dummies in these places. So instead, we just do a dispatch here.
// If in the future more calculators are added that need special behaviour,
// we might want to redesign this...
auto const* raw_calc = calculators_[i].get();
auto const* unbiased_calc = dynamic_cast<FstPoolCalculatorUnbiased const*>( raw_calc );
if( unbiased_calc ) {
results_[i] = unbiased_calc->get_result( window_length, filter_stats_ );
} else {
results_[i] = raw_calc->get_result();
}
}
return results_;
}
Expand All @@ -285,9 +282,9 @@ class FstPoolProcessor
* @brief Get lists of all the three intermediate pi values (within, between, total) that
* are part of our unbiased estimators.
*
* This computes the window-averaged values for pi within, pi between, and pi total (in that
* order in the tuple), for each sample pair (order in the three vectors). This uses the same
* window averaging policy as the get_result() function.
* This computes the window-average-corrected values for pi within, pi between, and pi total
* (in that order in the tuple), for each sample pair (order in the three vectors).
* This uses the same window averaging policy as the get_result() function.
*
* This only works when all calculators are of type FstPoolCalculatorUnbiased, and throws an
* exception otherwise. It is merely meant as a convenience function for that particular case.
Expand All @@ -296,31 +293,32 @@ class FstPoolProcessor
{
// Only allocate when someone first calls this.
// Does not do anything afterwards.
auto const result_size = calculators_.size();
std::get<0>( results_pi_ ).resize( result_size );
std::get<1>( results_pi_ ).resize( result_size );
std::get<2>( results_pi_ ).resize( result_size );
auto const res_sz = calculators_.size();
assert( std::get<0>( results_pi_ ).size() == 0 || std::get<0>( results_pi_ ).size() == res_sz );
assert( std::get<1>( results_pi_ ).size() == 0 || std::get<1>( results_pi_ ).size() == res_sz );
assert( std::get<2>( results_pi_ ).size() == 0 || std::get<2>( results_pi_ ).size() == res_sz );
std::get<0>( results_pi_ ).resize( res_sz );
std::get<1>( results_pi_ ).resize( res_sz );
std::get<2>( results_pi_ ).resize( res_sz );

// Get the pi values from all calculators, assuming that they are of the correct type.
for( size_t i = 0; i < result_size; ++i ) {
auto& raw_calc = calculators_[i];
auto cast_calc = dynamic_cast<FstPoolCalculatorUnbiased const*>( raw_calc.get() );
if( ! cast_calc ) {
for( size_t i = 0; i < res_sz; ++i ) {
auto const& raw_calc = calculators_[i];
auto const* unbiased_calc = dynamic_cast<FstPoolCalculatorUnbiased const*>( raw_calc.get() );
if( ! unbiased_calc ) {
throw std::domain_error(
"Can only call FstPoolProcessor::get_pi_vectors() "
"for calculators of type FstPoolCalculatorUnbiased."
);
}

// Get the denominator to use for all averaging.
auto const window_avg_denom = window_average_denominator(
avg_policy_, window_length, filter_stats_, calculators_[i]->get_filter_stats()
);

// We compute the window-averaged values as above.
std::get<0>(results_pi_)[i] = cast_calc->get_pi_within() / window_avg_denom;
std::get<1>(results_pi_)[i] = cast_calc->get_pi_between() / window_avg_denom;
std::get<2>(results_pi_)[i] = cast_calc->get_pi_total() / window_avg_denom;
// We compute the window-averaged values here.
// Unfortunately, we need to copy this value-by-value, as we want to return
// three independent vectors for user convenienec on the caller's end.
auto const pis = unbiased_calc->get_pi_values( window_length, filter_stats_ );
std::get<0>(results_pi_)[i] = pis.pi_within;
std::get<1>(results_pi_)[i] = pis.pi_between;
std::get<2>(results_pi_)[i] = pis.pi_total;
}

return results_pi_;
Expand Down Expand Up @@ -354,10 +352,6 @@ class FstPoolProcessor

private:

// We force the correct usage of the window averaging policy here,
// so that we make misinterpretation of the values less likely.
WindowAveragePolicy avg_policy_;

// The pairs of sample indices of the variant between which we want to compute FST,
// and the processors to use for these computations,
std::vector<std::pair<size_t, size_t>> sample_pairs_;
Expand All @@ -376,10 +370,6 @@ class FstPoolProcessor
// (number of sample pairs) at which we start using the thread pool.
std::shared_ptr<utils::ThreadPool> thread_pool_;
size_t threading_threshold_ = 0;

// We want to make sure to disallow default constructed instances.
// Bit ugly to do it this way, but works for now.
bool is_default_constructed_ = true;
};

// =================================================================================================
Expand All @@ -396,11 +386,10 @@ class FstPoolProcessor
*/
template<class Calculator, typename... Args>
inline FstPoolProcessor make_fst_pool_processor(
WindowAveragePolicy window_average_policy,
std::vector<size_t> const& pool_sizes,
Args... args
) {
FstPoolProcessor result( window_average_policy );
FstPoolProcessor result;
for( size_t i = 0; i < pool_sizes.size(); ++i ) {
for( size_t j = i + 1; j < pool_sizes.size(); ++j ) {
result.add_calculator(
Expand All @@ -427,12 +416,11 @@ inline FstPoolProcessor make_fst_pool_processor(
*/
template<class Calculator, typename... Args>
inline FstPoolProcessor make_fst_pool_processor(
WindowAveragePolicy window_average_policy,
std::vector<std::pair<size_t, size_t>> const& sample_pairs,
std::vector<size_t> const& pool_sizes,
Args... args
) {
FstPoolProcessor result( window_average_policy );
FstPoolProcessor result;
for( auto const& p : sample_pairs ) {
if( p.first >= pool_sizes.size() || p.second >= pool_sizes.size() ) {
throw std::invalid_argument(
Expand Down Expand Up @@ -465,12 +453,11 @@ inline FstPoolProcessor make_fst_pool_processor(
*/
template<class Calculator, typename... Args>
inline FstPoolProcessor make_fst_pool_processor(
WindowAveragePolicy window_average_policy,
size_t index,
std::vector<size_t> const& pool_sizes,
Args... args
) {
FstPoolProcessor result( window_average_policy );
FstPoolProcessor result;
for( size_t i = 0; i < pool_sizes.size(); ++i ) {
result.add_calculator(
index, i,
Expand All @@ -495,12 +482,11 @@ inline FstPoolProcessor make_fst_pool_processor(
*/
template<class Calculator, typename... Args>
inline FstPoolProcessor make_fst_pool_processor(
WindowAveragePolicy window_average_policy,
size_t index_1, size_t index_2,
std::vector<size_t> const& pool_sizes,
Args... args
) {
FstPoolProcessor result( window_average_policy );
FstPoolProcessor result;
result.add_calculator(
index_1, index_2,
::genesis::utils::make_unique<Calculator>(
Expand Down
Loading

0 comments on commit ddfb3fb

Please sign in to comment.