Skip to content

Commit

Permalink
Bug fix: Optimizing branch lengths in NNI Engine via GP (#484)
Browse files Browse the repository at this point in the history
- Fixes bug that caused NNI search to not optimize new branches when using a GP evaluation engine.
- Adds a pybito function that gives read access to the branch lengths of the NNIEngine's evaluation engine (throws an error if a parsimony engine, which does not have branch lengths, is in use).
- Minor fixes to fully_connect and git_commit functions.

Closes #483 

---------

Co-authored-by: Chris Jennings-Shaffer <[email protected]>
  • Loading branch information
davidrich27 and chrisjenningsshaffer authored Dec 19, 2023
1 parent 7cc6b60 commit d8eded6
Show file tree
Hide file tree
Showing 14 changed files with 253 additions and 201 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ _modules
_sources
.RData
.Rhistory
test/_ignore/*
test/_out/*

# Python packaging
dist/
Expand All @@ -60,7 +62,7 @@ src/CMakeLists.txt
.DS_Store
._.DS_Store

# Developer Tools
# Developer Tools
.vscode
*.code-workspace
.gdbinit
Expand Down
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ if(GIT_FOUND)
)
endif()

configure_file("${CMAKE_CURRENT_SOURCE_DIR}/src/sugar_version.hpp.in" "${PROJECT_BINARY_DIR}/src/sugar_version.hpp" @ONLY)

set(CMAKE_SKIP_RPATH ON)
execute_process(COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/get_rpath.py
OUTPUT_VARIABLE BITO_RPATH
Expand Down Expand Up @@ -74,9 +76,6 @@ option(PROFILING "Compile with debugger and profiling symbols" OFF)
function(bito_compile_opts PRODUCT WERROR_)
target_compile_features(${PRODUCT} PUBLIC cxx_std_17)
target_compile_options(${PRODUCT} PUBLIC -Wno-unknown-warning -Wno-unknown-warning-option -Wall)
target_compile_definitions(${PRODUCT} PUBLIC -DGIT_HASH="${GIT_HASH}")
target_compile_definitions(${PRODUCT} PUBLIC -DGIT_BRANCH="${GIT_BRANCH}")
target_compile_definitions(${PRODUCT} PUBLIC -DGIT_TAGS="${GIT_TAGS}")

if(${WERROR_})
target_compile_options(${PRODUCT} PUBLIC -Werror)
Expand All @@ -91,6 +90,7 @@ function(bito_compile_opts PRODUCT WERROR_)
target_include_directories(${PRODUCT} PUBLIC
${PROJECT_BINARY_DIR}/beagle-lib/install/include/libhmsbeagle-1
lib/eigen
${PROJECT_BINARY_DIR}/src
)
endfunction()

Expand Down
96 changes: 96 additions & 0 deletions src/dag_branch_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,50 @@

#include "dag_branch_handler.hpp"

// ** Constructors

// Builds a handler with storage for `count` branches. This overload does not
// bind a DAG.
//
// count:  number of elements allocated for branch lengths and differences.
// method: optional optimization method; when absent, SetOptimizationMethod is
//         not called.
DAGBranchHandler::DAGBranchHandler(const size_t count,
                                   std::optional<OptimizationMethod> method)
    : branch_lengths_(count), differences_(count) {
  // Only override the optimization method if the caller supplied one.
  if (method) {
    SetOptimizationMethod(*method);
  }

  // Registers `value` as the container's default, then fills the container
  // with that default.
  const auto seed_with_default = [](auto& data, const auto value) {
    data.SetDefaultValue(value);
    data.FillWithDefault();
  };
  seed_with_default(branch_lengths_, init_default_branch_length_);
  seed_with_default(differences_, init_default_difference_);
}

// Builds a handler bound to `dag`: both data containers are constructed from
// the DAG and a pointer to the DAG is kept in dag_.
//
// dag:    DAG whose branches this handler stores data for. Only its address is
//         stored, so it must outlive the handler.
// method: optional optimization method; when absent, SetOptimizationMethod is
//         not called.
DAGBranchHandler::DAGBranchHandler(GPDAG& dag, std::optional<OptimizationMethod> method)
    : branch_lengths_(dag), differences_(dag), dag_(&dag) {
  // Only override the optimization method if the caller supplied one.
  if (method) {
    SetOptimizationMethod(*method);
  }

  // Registers `value` as the container's default, then fills the container
  // with that default.
  const auto seed_with_default = [](auto& data, const auto value) {
    data.SetDefaultValue(value);
    data.FillWithDefault();
  };
  seed_with_default(branch_lengths_, init_default_branch_length_);
  seed_with_default(differences_, init_default_difference_);
}

// ** Comparators

// Three-way comparison of two handlers: first by size, then element-wise by
// branch length in edge-id order.
//
// Returns a negative value if lhs orders before rhs, a positive value if it
// orders after, and 0 only when the sizes and every branch length are equal.
//
// The result is derived from direct comparisons rather than from subtraction.
// Converting a branch-length difference to int truncates toward zero, so any
// two lengths less than 1.0 apart would be reported as equal; subtracting the
// sizes is done in unsigned arithmetic and wraps before being narrowed.
int DAGBranchHandler::Compare(const DAGBranchHandler& lhs,
                              const DAGBranchHandler& rhs) {
  if (lhs.size() != rhs.size()) {
    return (lhs.size() < rhs.size()) ? -1 : 1;
  }
  for (EdgeId edge_id{0}; edge_id < lhs.size(); edge_id++) {
    const auto lhs_length = lhs.Get(edge_id);
    const auto rhs_length = rhs.Get(edge_id);
    if (lhs_length != rhs_length) {
      // First mismatching branch decides the ordering.
      return (lhs_length < rhs_length) ? -1 : 1;
    }
  }
  return 0;
}

// Equality of two handlers, defined as DAGBranchHandler::Compare returning 0.
bool operator==(const DAGBranchHandler& lhs, const DAGBranchHandler& rhs) {
  const int ordering = DAGBranchHandler::Compare(lhs, rhs);
  return ordering == 0;
}

// ** Branch Length Map

DAGBranchHandler::BranchLengthMap DAGBranchHandler::BuildBranchLengthMap(
Expand All @@ -24,6 +68,56 @@ void DAGBranchHandler::ApplyBranchLengthMap(
}
}

// ** Static Functions

// Builds a rooted tree on `topology` whose branch lengths are read from
// `dag_branch_handler`.
//
// Each (sister, focal, child_left, child_right) tuple visited by the PCSP
// preorder traversal is turned into a parent subsplit and a child subsplit.
// The DAG edge between them is looked up and its branch length is stored at
// the focal node's id. Any of the sister or child nodes that is a leaf gets
// the length of the edge from its parent subsplit to its leaf subsplit.
//
// dag:                DAG used to map subsplit pairs to edge ids.
// dag_branch_handler: source of branch lengths, indexed by DAG edge id.
// topology:           tree topology to decorate with branch lengths.
RootedTree DAGBranchHandler::BuildTreeWithBranchLengthsFromTopology(
    const GPDAG& dag, const DAGBranchHandler& dag_branch_handler,
    const Node::Topology& topology) {
  // One zero-initialized slot per node id.
  const auto node_count = 2 * topology->LeafCount() - 1;
  Tree::BranchLengthVector lengths_by_node(node_count, 0.0);

  // Stores the length of DAG edge (parent -> child) at the slot of `node`.
  const auto record_edge = [&dag, &dag_branch_handler, &lengths_by_node](
                               const Node* node, const Bitset& parent,
                               const Bitset& child) {
    const EdgeId edge_id = dag.GetEdgeIdx(parent, child);
    lengths_by_node[node->Id()] = dag_branch_handler(edge_id);
  };

  // For a leaf node, records the edge from `parent` to the node's leaf
  // subsplit; does nothing for internal nodes.
  const auto record_if_leaf = [&record_edge](const Node* node, const Bitset& parent) {
    if (!node->IsLeaf()) {
      return;
    }
    const Bitset leaf_subsplit = Bitset::LeafSubsplitOfNonemptyClade(node->Leaves());
    record_edge(node, parent, leaf_subsplit);
  };

  topology->RootedPCSPPreorder(
      [&record_edge, &record_if_leaf](const Node* sister, const Node* focal,
                                      const Node* child_left,
                                      const Node* child_right) {
        const Bitset parent_subsplit =
            Bitset::Subsplit(sister->Leaves(), focal->Leaves());
        const Bitset child_subsplit =
            Bitset::Subsplit(child_left->Leaves(), child_right->Leaves());

        // Central edge of the PCSP belongs to the focal node.
        record_edge(focal, parent_subsplit, child_subsplit);

        // Edges to adjacent leaves, in the order sister, left child, right child.
        record_if_leaf(sister, parent_subsplit);
        record_if_leaf(child_left, child_subsplit);
        record_if_leaf(child_right, child_subsplit);
      },
      false);

  return RootedTree(topology, std::move(lengths_by_node));
}

void DAGBranchHandler::CopyOverBranchLengths(const DAGBranchHandler& src,
DAGBranchHandler& dest) {
const auto& src_dag = src.GetDAG();
const auto& dest_dag = dest.GetDAG();
for (EdgeId dest_id = 0; dest_id < dest_dag.EdgeCountWithLeafSubsplits(); dest_id++) {
const auto& pcsp = dest_dag.GetDAGEdgeBitset(dest_id);
const auto src_id = src_dag.GetEdgeIdx(pcsp);
dest.Get(dest_id) = src.Get(src_id);
}
}

// ** Optimization

void DAGBranchHandler::OptimizeBranchLength(const EdgeId edge_id, const PVId parent_id,
Expand Down Expand Up @@ -51,6 +145,8 @@ void DAGBranchHandler::OptimizeBranchLength(const EdgeId edge_id, const PVId par
}
}

// ** Branch

void DAGBranchHandler::BrentOptimization(const EdgeId edge_id, const PVId parent_id,
const PVId child_id) {
Assert(brent_nongrad_func_ != nullptr,
Expand Down
85 changes: 6 additions & 79 deletions src/dag_branch_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,13 @@
class DAGBranchHandler {
public:
DAGBranchHandler(const size_t count,
std::optional<OptimizationMethod> method = std::nullopt)
: branch_lengths_(count), differences_(count) {
if (method.has_value()) {
SetOptimizationMethod(method.value());
}
branch_lengths_.SetDefaultValue(init_default_branch_length_);
branch_lengths_.FillWithDefault();
differences_.SetDefaultValue(init_default_difference_);
differences_.FillWithDefault();
}
DAGBranchHandler(GPDAG& dag, std::optional<OptimizationMethod> method = std::nullopt)
: branch_lengths_(dag), differences_(dag), dag_(&dag) {
if (method.has_value()) {
SetOptimizationMethod(method.value());
}
branch_lengths_.SetDefaultValue(init_default_branch_length_);
branch_lengths_.FillWithDefault();
differences_.SetDefaultValue(init_default_difference_);
differences_.FillWithDefault();
}
std::optional<OptimizationMethod> method = std::nullopt);
DAGBranchHandler(GPDAG& dag, std::optional<OptimizationMethod> method = std::nullopt);

// ** Comparators

static int Compare(const DAGBranchHandler& lhs, const DAGBranchHandler& rhs) {
if (lhs.size() != rhs.size()) {
return lhs.size() - rhs.size();
}
for (EdgeId edge_id{0}; edge_id < lhs.size(); edge_id++) {
if (lhs.Get(edge_id) != rhs.Get(edge_id)) {
return lhs.Get(edge_id) - rhs.Get(edge_id);
}
}
return 0;
}
friend bool operator==(const DAGBranchHandler& lhs, const DAGBranchHandler& rhs) {
return Compare(lhs, rhs) == 0;
}
static int Compare(const DAGBranchHandler& lhs, const DAGBranchHandler& rhs);
friend bool operator==(const DAGBranchHandler& lhs, const DAGBranchHandler& rhs);

// ** Counts

Expand Down Expand Up @@ -254,55 +224,12 @@ class DAGBranchHandler {
// out PCSP bitset to find corresponding DAG EdgeId, to find branch length.
static RootedTree BuildTreeWithBranchLengthsFromTopology(
const GPDAG& dag, const DAGBranchHandler& dag_branch_handler,
const Node::Topology& topology) {
Tree::BranchLengthVector tree_branch_lengths(2 * topology->LeafCount() - 1, 0.0);

topology->RootedPCSPPreorder(
[&dag, &dag_branch_handler, &tree_branch_lengths](
const Node* sister, const Node* focal, const Node* child_left,
const Node* child_right) {
Bitset parent_subsplit = Bitset::Subsplit(sister->Leaves(), focal->Leaves());
Bitset child_subsplit =
Bitset::Subsplit(child_left->Leaves(), child_right->Leaves());
EdgeId focal_edge_id = dag.GetEdgeIdx(parent_subsplit, child_subsplit);
tree_branch_lengths[focal->Id()] = dag_branch_handler(focal_edge_id);

// If adjacent nodes go to leaves.
if (sister->IsLeaf()) {
Bitset subsplit = Bitset::LeafSubsplitOfNonemptyClade(sister->Leaves());
EdgeId edge_id = dag.GetEdgeIdx(parent_subsplit, subsplit);
tree_branch_lengths[sister->Id()] = dag_branch_handler(edge_id);
}
if (child_left->IsLeaf()) {
Bitset subsplit = Bitset::LeafSubsplitOfNonemptyClade(child_left->Leaves());
EdgeId edge_id = dag.GetEdgeIdx(child_subsplit, subsplit);
tree_branch_lengths[child_left->Id()] = dag_branch_handler(edge_id);
}
if (child_right->IsLeaf()) {
Bitset subsplit =
Bitset::LeafSubsplitOfNonemptyClade(child_right->Leaves());
EdgeId edge_id = dag.GetEdgeIdx(child_subsplit, subsplit);
tree_branch_lengths[child_right->Id()] = dag_branch_handler(edge_id);
}
},
false);

return RootedTree(topology, std::move(tree_branch_lengths));
}
const Node::Topology& topology);

// Copies branch lengths from one handler to another. Base DAG of dest handler must be
// a subgraph of src handler.
static void CopyOverBranchLengths(const DAGBranchHandler& src,
DAGBranchHandler& dest) {
const auto& src_dag = src.GetDAG();
const auto& dest_dag = dest.GetDAG();
for (EdgeId dest_id = 0; dest_id < dest_dag.EdgeCountWithLeafSubsplits();
dest_id++) {
const auto& pcsp = dest_dag.GetDAGEdgeBitset(dest_id);
const auto src_id = src_dag.GetEdgeIdx(pcsp);
dest.Get(dest_id) = src.Get(src_id);
}
}
DAGBranchHandler& dest);

protected:
// ** Branch Length Optimization Helpers
Expand Down
4 changes: 2 additions & 2 deletions src/dag_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ class DAGData {
VectorType &GetData() { return data_vec_; }
const VectorType &GetData() const { return data_vec_; }
// Get data corresponding to the DAG elements.
VectorType GetDAGData() { return data_vec_.segment(0, GetCount()); }
VectorType GetDAGData() const { return data_vec_.segment(0, GetCount()); }
// Get data corresponding to the DAG elements, including spare elements.
VectorType GetPaddedDAGData() { return data_vec_.segment(0, GetPaddedCount()); }
VectorType GetPaddedDAGData() const { return data_vec_.segment(0, GetPaddedCount()); }
// Get data corresponding to the DAG elements.
void SetDAGData(const VectorType &data_vec) {
Assert(GetCount() == size_t(data_vec.size()),
Expand Down
13 changes: 13 additions & 0 deletions src/nni_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,19 @@ void NNIEngine::UpdateEvalEngineAfterModifyingDAG(
}
}

// Returns the branch handler of the evaluation engine currently in use.
//
// The GP engine is checked first. Both TP variants (likelihood and parsimony)
// resolve to the TP evaluation engine's handler. If none of these engines is
// in use, the call fails via Failwith.
const DAGBranchHandler &NNIEngine::GetDAGBranchHandler() const {
  if (IsEvalEngineInUse(NNIEvalEngineType::GPEvalEngine)) {
    return GetGPEvalEngine().GetDAGBranchHandler();
  }

  const bool tp_engine_in_use =
      IsEvalEngineInUse(NNIEvalEngineType::TPEvalEngineViaLikelihood) ||
      IsEvalEngineInUse(NNIEvalEngineType::TPEvalEngineViaParsimony);
  if (tp_engine_in_use) {
    return GetTPEvalEngine().GetDAGBranchHandler();
  }

  Failwith("Invalid given EvalEngineType.");
}

// ** Runners

void NNIEngine::Run(const bool is_quiet) {
Expand Down
11 changes: 8 additions & 3 deletions src/nni_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,14 @@ class NNIEngine {
// Reset number of iterations.
void ResetIterationCount() { iter_count_ = 0; }

// ** Counts

// ** NNI Evaluation Engine

// Set GP Engine.
NNIEvalEngineViaGP &MakeGPEvalEngine(GPEngine *gp_engine);
// Set TP Engine.
NNIEvalEngineViaTP &MakeTPEvalEngine(TPEngine *tp_engine);
// Check if evaluation engine is currently in use.
bool IsEvalEngineInUse(const NNIEvalEngineType eval_engine_type) {
bool IsEvalEngineInUse(const NNIEvalEngineType eval_engine_type) const {
return eval_engine_in_use_[eval_engine_type];
}
// Remove all evaluation engines from use.
Expand Down Expand Up @@ -258,6 +256,13 @@ class NNIEngine {
void GrowEvalEngineForAdjacentNNIs(const bool via_reference = true,
const bool use_unique_temps = false);

// Get evaluation engine's branch length handler.
const DAGBranchHandler &GetDAGBranchHandler() const;
// Get a copy of the branch lengths held by the evaluation engine's branch
// handler, restricted to the data corresponding to the DAG elements.
const EigenVectorXd GetBranchLengths() const {
  const auto &branch_handler = GetDAGBranchHandler();
  const auto &branch_lengths = branch_handler.GetBranchLengths();
  return branch_lengths.GetDAGData();
}

// ** Runners
// These start the engine, which procedurally ranks and adds (and maybe removes) NNIs
// to the DAG, until some termination criteria has been satisfied.
Expand Down
20 changes: 7 additions & 13 deletions src/nni_evaluation_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ void NNIEvalEngineViaGP::UpdateEngineAfterModifyingDAG(
const size_t prev_node_count, const Reindexer &node_reindexer,
const size_t prev_edge_count, const Reindexer &edge_reindexer) {
using namespace GPOperations;
const bool copy_branch_lengths = false;

auto &branch_handler = GetGPEngine().GetBranchLengthHandler();
// Find all new edge ids.
std::set<EdgeId> new_edge_ids;
Expand All @@ -110,7 +108,7 @@ void NNIEvalEngineViaGP::UpdateEngineAfterModifyingDAG(
branch_handler(edge_id) = branch_handler.GetDefaultBranchLength();
}
// Copy over branch lengths from pre-NNI to post-NNI.
if (copy_branch_lengths) {
if (copy_new_edges_) {
for (const auto &[pre_nni, nni] : pre_nni_to_nni) {
CopyGPEngineDataAfterAddingNNI(pre_nni, nni);
}
Expand All @@ -132,16 +130,11 @@ void NNIEvalEngineViaGP::UpdateEngineAfterModifyingDAG(
GetGPEngine().ProcessOperations(GetDAG().PopulatePLVs());

// Optimize branch lengths.
if (IsOptimizeNewEdges()) {
for (const auto &[pre_nni, nni] : pre_nni_to_nni) {
std::ignore = pre_nni;
NNIBranchLengthOptimization(nni, new_edge_ids);
}
GetGPEngine().ProcessOperations(GetDAG().PopulatePLVs());
if (optimize_new_edges_) {
BranchLengthOptimization();
}

GetGPEngine().ProcessOperations(GetDAG().ComputeLikelihoods());
auto likelihoods = GetGPEngine().GetPerGPCSPLogLikelihoods();
}

void NNIEvalEngineViaGP::CopyGPEngineDataAfterAddingNNI(const NNIOperation &pre_nni,
Expand Down Expand Up @@ -841,10 +834,11 @@ void NNIEvalEngineViaGP::BranchLengthOptimization() {

void NNIEvalEngineViaGP::BranchLengthOptimization(
const std::set<EdgeId> &edges_to_optimize) {
const auto ops = GetDAG().BranchLengthOptimization(edges_to_optimize);
const auto update_ops = GetDAG().PopulatePLVs();
const auto optimize_ops = GetDAG().BranchLengthOptimization(edges_to_optimize);
for (size_t iter = 0; iter < GetOptimizationMaxIteration(); iter++) {
GetGPEngine().ProcessOperations(ops);
GetGPEngine().ProcessOperations(GetDAG().PopulatePLVs());
GetGPEngine().ProcessOperations(optimize_ops);
GetGPEngine().ProcessOperations(update_ops);
}
}

Expand Down
Loading

0 comments on commit d8eded6

Please sign in to comment.