Add weighted clustering
AnnSeidel committed Nov 14, 2022
1 parent 1fea43d commit bd080e6
Showing 12 changed files with 252 additions and 13 deletions.
18 changes: 17 additions & 1 deletion src/clustering/Clustering.cpp
@@ -4,10 +4,12 @@
#include "Util.h"
#include "itoa.h"
#include "Timer.h"
#include "SequenceWeights.h"

Clustering::Clustering(const std::string &seqDB, const std::string &seqDBIndex,
const std::string &alnDB, const std::string &alnDBIndex,
const std::string &outDB, const std::string &outDBIndex,
const std::string &sequenceWeightFile,
unsigned int maxIteration, int similarityScoreType, int threads, int compressed) : maxIteration(maxIteration),
similarityScoreType(similarityScoreType),
threads(threads),
@@ -16,7 +18,21 @@ Clustering::Clustering(const std::string &seqDB, const std::string &seqDBIndex,
outDBIndex(outDBIndex) {

seqDbr = new DBReader<unsigned int>(seqDB.c_str(), seqDBIndex.c_str(), threads, DBReader<unsigned int>::USE_INDEX);
seqDbr->open(DBReader<unsigned int>::SORT_BY_LENGTH);

if (!sequenceWeightFile.empty()) {

seqDbr->open(DBReader<unsigned int>::SORT_BY_ID);

SequenceWeights *sequenceWeights = new SequenceWeights(sequenceWeightFile.c_str());
float localid2weight[seqDbr->getSize()];
for (size_t id = 0; id < seqDbr->getSize(); id++) {
size_t key = seqDbr->getDbKey(id);
localid2weight[id] = sequenceWeights->getWeightById(key);
}
seqDbr->sortIndex(localid2weight);

} else
seqDbr->open(DBReader<unsigned int>::SORT_BY_LENGTH);

alnDbr = new DBReader<unsigned int>(alnDB.c_str(), alnDBIndex.c_str(), threads, DBReader<unsigned int>::USE_DATA|DBReader<unsigned int>::USE_INDEX);
alnDbr->open(DBReader<unsigned int>::NOSORT);
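
When a weight file is given, the constructor now opens the sequence DB sorted by ID, translates every entry's key into its weight, and re-sorts the reader's index by those weights, so that high-weight sequences are considered first by the greedy assignment later in this diff and therefore tend to become cluster representatives. Below is a minimal self-contained sketch of that idea; it does not use the MMseqs2 classes, the names (keys, weightOf, order) are illustrative only, and a std::vector stands in for the stack-allocated variable-length localid2weight array used above:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // toy database: entry i has key keys[i]; weights come from an external table keyed by key
    std::vector<unsigned int> keys = {10, 11, 12, 13};
    auto weightOf = [](unsigned int key) { return key == 12 ? 0.95f : 0.10f; };

    // per-entry weight lookup, analogous to localid2weight[id] = getWeightById(getDbKey(id))
    std::vector<float> localWeight(keys.size());
    for (size_t id = 0; id < keys.size(); ++id)
        localWeight[id] = weightOf(keys[id]);

    // order entries by descending weight (ties keep original order), analogous to sortIndex(weights)
    std::vector<size_t> order(keys.size());
    for (size_t i = 0; i < order.size(); ++i) order[i] = i;
    std::stable_sort(order.begin(), order.end(),
                     [&](size_t a, size_t b) { return localWeight[a] > localWeight[b]; });

    for (size_t i = 0; i < order.size(); ++i)
        std::printf("rank %zu -> key %u (weight %.2f)\n", i, keys[order[i]], localWeight[order[i]]);
    return 0;
}
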
1 change: 1 addition & 0 deletions src/clustering/Clustering.h
@@ -11,6 +11,7 @@ class Clustering {
Clustering(const std::string &seqDB, const std::string &seqDBIndex,
const std::string &alnResultsDB, const std::string &alnResultsDBIndex,
const std::string &outDB, const std::string &outDBIndex,
const std::string &weightFileName,
unsigned int maxIteration, int similarityScoreType, int threads, int compressed);

void run(int mode);
2 changes: 1 addition & 1 deletion src/clustering/ClusteringAlgorithms.cpp
@@ -282,7 +282,7 @@ void ClusteringAlgorithms::greedyIncrementalLowMem( unsigned int *assignedcluste
#pragma omp for schedule(dynamic, 1000)
for(size_t i = 0; i < dbSize; i++) {
unsigned int clusterKey = seqDbr->getDbKey(i);
unsigned int clusterId = seqDbr->getId(clusterKey);
unsigned int clusterId = i;

// try to set yourself as cluster centroid
// if some other cluster covered
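
Using the loop index directly is equivalent to the removed lookup because DBReader keeps id2local and local2id as inverse permutations (see sortIndex in DBReader.cpp further down in this diff): getDbKey(i) reads the key stored at global position local2id[i], and getId of that key maps back through id2local, so getId(getDbKey(i)) reduces to id2local[local2id[i]] == i when keys are unique. A tiny standalone sketch of that invariant with toy permutations:

#include <cassert>
#include <vector>

int main() {
    // local2id and id2local as built in DBReader::sortIndex: inverse permutations of each other
    std::vector<unsigned int> local2id = {2, 0, 3, 1};
    std::vector<unsigned int> id2local(local2id.size());
    for (size_t i = 0; i < local2id.size(); ++i)
        id2local[local2id[i]] = (unsigned int)i;

    // the key->id round trip id2local[local2id[i]] is the identity,
    // so "clusterId = i" skips a redundant binary search over the key index
    for (size_t i = 0; i < local2id.size(); ++i)
        assert(id2local[local2id[i]] == i);
    return 0;
}
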
2 changes: 1 addition & 1 deletion src/clustering/Main.cpp
@@ -6,7 +6,7 @@ int clust(int argc, const char **argv, const Command& command) {
par.parseParameters(argc, argv, command, true, 0, 0);

Clustering clu(par.db1, par.db1Index, par.db2, par.db2Index,
par.db3, par.db3Index, par.maxIteration,
par.db3, par.db3Index, par.weightFile, par.maxIteration,
par.similarityScoreType, par.threads, par.compressed);
clu.run(par.clusteringMode);
return EXIT_SUCCESS;
2 changes: 2 additions & 0 deletions src/commons/CMakeLists.txt
@@ -34,6 +34,7 @@ set(commons_header_files
commons/PatternCompiler.h
commons/ScoreMatrix.h
commons/Sequence.h
commons/SequenceWeights.h
commons/StringBlock.h
commons/SubstitutionMatrix.h
commons/SubstitutionMatrixProfileStates.h
@@ -69,6 +70,7 @@ set(commons_source_files
commons/ProfileStates.cpp
commons/LibraryReader.cpp
commons/Sequence.cpp
commons/SequenceWeights.cpp
commons/SubstitutionMatrix.cpp
commons/tantan.cpp
commons/UniprotKB.cpp
27 changes: 26 additions & 1 deletion src/commons/DBReader.cpp
@@ -213,6 +213,9 @@ template <typename T> bool DBReader<T>::open(int accessType){
template<typename T>
void DBReader<T>::sortIndex(bool) {
}
template<typename T>
void DBReader<T>::sortIndex(float*) {
}

template<typename T>
bool DBReader<T>::isSortedByOffset(){
@@ -234,8 +237,31 @@ void DBReader<std::string>::sortIndex(bool isSortedById) {
}
}

template<>
void DBReader<unsigned int>::sortIndex(float *weights) {

this->accessType=DBReader::SORT_BY_WEIGHTS;
std::pair<unsigned int, float> *sortForMapping = new std::pair<unsigned int, float>[size];
id2local = new unsigned int[size];
local2id = new unsigned int[size];
incrementMemory(sizeof(unsigned int) * 2 * size);
for (size_t i = 0; i < size; i++) {
id2local[i] = i;
local2id[i] = i;
sortForMapping[i] = std::make_pair(i, weights[i]);
}
// this sort has to be stable to ensure the same clustering results
SORT_PARALLEL(sortForMapping, sortForMapping + size, comparePairByWeight());
for (size_t i = 0; i < size; i++) {
id2local[sortForMapping[i].first] = i;
local2id[i] = sortForMapping[i].first;
}
delete[] sortForMapping;
}

template<>
void DBReader<unsigned int>::sortIndex(bool isSortedById) {

// First, we sort the index by IDs and we keep track of the original
// ordering in mappingToOriginalIndex array
size_t* mappingToOriginalIndex=NULL;
@@ -294,7 +320,6 @@ void DBReader<unsigned int>::sortIndex(bool isSortedById) {
mappingToOriginalIndex[i] = i;
}
}

if (accessType == SORT_BY_LENGTH) {
// sort the entries by the length of the sequences
std::pair<unsigned int, unsigned int> *sortForMapping = new std::pair<unsigned int, unsigned int>[size];
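
DBReader<unsigned int>::sortIndex(float*) builds the same id2local/local2id indirection that the other access types use, just keyed on the caller-supplied weights: after the sort, local position 0 is the highest-weight entry, and the two arrays translate between original and weight-sorted positions. A self-contained sketch of that permutation under the same ordering rule (descending weight, original index as tie-breaker, mirroring comparePairByWeight in DBReader.h), with toy values:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    std::vector<float> weights = {0.2f, 0.9f, 0.9f, 0.1f};
    const size_t size = weights.size();

    std::vector<std::pair<unsigned int, float> > sortForMapping(size);
    std::vector<unsigned int> id2local(size), local2id(size);
    for (size_t i = 0; i < size; ++i)
        sortForMapping[i] = std::make_pair((unsigned int)i, weights[i]);

    // highest weight first; equal weights fall back to the smaller original index,
    // which keeps the resulting order deterministic across runs
    std::sort(sortForMapping.begin(), sortForMapping.end(),
              [](const std::pair<unsigned int, float> &a, const std::pair<unsigned int, float> &b) {
                  return a.second != b.second ? a.second > b.second : a.first < b.first;
              });

    for (size_t i = 0; i < size; ++i) {
        id2local[sortForMapping[i].first] = (unsigned int)i;
        local2id[i] = sortForMapping[i].first;
    }
    // local position 0 now refers to original entry 1 (weight 0.9)
    std::printf("local 0 -> original %u, original 1 -> local %u\n", local2id[0], id2local[1]);
    return 0;
}
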
16 changes: 16 additions & 0 deletions src/commons/DBReader.h
@@ -257,6 +257,7 @@ class DBReader : public MemoryTracker {
static const int HARDNOSORT = 6; // do not even sort by ids.
static const int SORT_BY_ID_OFFSET = 7;
static const int SORT_BY_OFFSET = 8; // only offset sorting saves memory and does not support random access
static const int SORT_BY_WEIGHTS = 9;


static const unsigned int USE_INDEX = 0;
@@ -317,6 +318,7 @@ class DBReader : public MemoryTracker {
void mlock();

void sortIndex(bool isSortedById);
void sortIndex(float *weights);
bool isSortedByOffset();

void unmapData();
@@ -392,6 +394,20 @@
}
};

struct comparePairByWeight {
bool operator() (const std::pair<unsigned int, float>& lhs, const std::pair<unsigned int, float>& rhs) const{
if(lhs.second > rhs.second)
return true;
if(rhs.second > lhs.second)
return false;
if(lhs.first < rhs.first )
return true;
if(rhs.first < lhs.first )
return false;
return false;
}
};

struct comparePairByIdAndOffset {
bool operator() (const std::pair<unsigned int, Index>& lhs, const std::pair<unsigned int, Index>& rhs) const{
if(lhs.second.id < rhs.second.id)
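
comparePairByWeight orders pairs by descending weight and breaks ties on the index, so the ordering is total and the weight-sorted index is reproducible regardless of whether SORT_PARALLEL happens to be a stable sort. A tiny standalone check of that behaviour, using a slightly condensed local copy of the comparator rather than including DBReader.h:

#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

struct comparePairByWeight {
    bool operator()(const std::pair<unsigned int, float> &lhs,
                    const std::pair<unsigned int, float> &rhs) const {
        if (lhs.second > rhs.second) return true;
        if (rhs.second > lhs.second) return false;
        return lhs.first < rhs.first; // tie on weight: smaller index first
    }
};

int main() {
    std::vector<std::pair<unsigned int, float> > v = {{3, 0.5f}, {1, 0.5f}, {2, 0.9f}};
    std::sort(v.begin(), v.end(), comparePairByWeight());
    // highest weight first; the two 0.5 entries keep index order 1 before 3
    assert(v[0].first == 2 && v[1].first == 1 && v[2].first == 3);
    return 0;
}
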
10 changes: 9 additions & 1 deletion src/commons/Parameters.cpp
@@ -151,7 +151,8 @@ Parameters::Parameters():
PARAM_PICK_N_SIMILAR(PARAM_PICK_N_SIMILAR_ID, "--pick-n-sim-kmer", "Add N similar to search", "Add N similar k-mers to search", typeid(int), (void *) &pickNbest, "^[1-9]{1}[0-9]*$", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT),
PARAM_ADJUST_KMER_LEN(PARAM_ADJUST_KMER_LEN_ID, "--adjust-kmer-len", "Adjust k-mer length", "Adjust k-mer length based on specificity (only for nucleotides)", typeid(bool), (void *) &adjustKmerLength, "", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT),
PARAM_RESULT_DIRECTION(PARAM_RESULT_DIRECTION_ID, "--result-direction", "Result direction", "result is 0: query, 1: target centric", typeid(int), (void *) &resultDirection, "^[0-1]{1}$", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT),
PARAM_WEIGHT_FILE(PARAM_WEIGHT_FILE_ID, "--weights", "Weight file name", "Weights used for cluster prioritization", typeid(std::string), (void*) &weightFile, "", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT ),
PARAM_WEIGHT_THR(PARAM_WEIGHT_THR_ID, "--weightThr", "Weight threshold", "Weight threshold used for cluster prioritization", typeid(float), (void*) &weightThr, "^[0-9]*(\\.[0-9]+)?$", MMseqsParameter::COMMAND_CLUSTLINEAR | MMseqsParameter::COMMAND_EXPERT ),
// workflow
PARAM_RUNNER(PARAM_RUNNER_ID, "--mpi-runner", "MPI runner", "Use MPI on compute cluster with this MPI command (e.g. \"mpirun -np 42\")", typeid(std::string), (void *) &runner, "", MMseqsParameter::COMMAND_COMMON | MMseqsParameter::COMMAND_EXPERT),
PARAM_REUSELATEST(PARAM_REUSELATEST_ID, "--force-reuse", "Force restart with latest tmp", "Reuse tmp files in tmp/latest folder ignoring parameters and version changes", typeid(bool), (void *) &reuseLatest, "", MMseqsParameter::COMMAND_COMMON | MMseqsParameter::COMMAND_EXPERT),
@@ -435,6 +436,8 @@ Parameters::Parameters():
clust.push_back(&PARAM_THREADS);
clust.push_back(&PARAM_COMPRESSED);
clust.push_back(&PARAM_V);
clust.push_back(&PARAM_WEIGHT_FILE);
clust.push_back(&PARAM_WEIGHT_THR);
// rescorediagonal
rescorediagonal.push_back(&PARAM_SUB_MAT);
@@ -914,6 +917,8 @@ Parameters::Parameters():
kmermatcher.push_back(&PARAM_THREADS);
kmermatcher.push_back(&PARAM_COMPRESSED);
kmermatcher.push_back(&PARAM_V);
kmermatcher.push_back(&PARAM_WEIGHT_FILE);
kmermatcher.push_back(&PARAM_WEIGHT_THR);
// kmersearch
kmersearch.push_back(&PARAM_SEED_SUB_MAT);
@@ -2431,6 +2436,9 @@ void Parameters::setDefaults() {
pickNbest = 1;
adjustKmerLength = false;
resultDirection = Parameters::PARAM_RESULT_DIRECTION_TARGET;
weightThr = 0.9;
weightFile = "";

// result2stats
stat = "";

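
The two new user-facing options are --weights (path to a per-sequence weight file, default empty) and --weightThr (a float defaulting to 0.9, validated by the regex above); both are registered for clust and kmermatcher. As a quick illustration of which strings the --weightThr pattern accepts, here is a standalone check with std::regex using the same pattern (the candidate values are made up):

#include <cstdio>
#include <regex>

int main() {
    // same pattern as PARAM_WEIGHT_THR: optional integer part plus an optional ".digits" tail
    const std::regex pattern("^[0-9]*(\\.[0-9]+)?$");
    const char *candidates[] = {"0.9", "1", "12.5", "-0.5", "abc"};
    for (const char *c : candidates)
        std::printf("%-5s %s\n", c, std::regex_match(c, pattern) ? "accepted" : "rejected");
    return 0;
}

With these registered, a weighted clustering run would presumably look something like mmseqs clust seqDB alnDB cluDB --weights weights.tsv --weightThr 0.9 (file name hypothetical), with the weight file in the key/weight format parsed by SequenceWeights.cpp below.
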
6 changes: 6 additions & 0 deletions src/commons/Parameters.h
@@ -400,6 +400,7 @@ class Parameters {
std::string spacedKmerPattern; // User-specified kmer pattern
std::string localTmp; // Local temporary path


// ALIGNMENT
int alignmentMode; // alignment mode 0=fastest on parameters,
// 1=score only, 2=score, cov, start/end pos, 3=score, cov, start/end pos, seq.id,
@@ -526,6 +527,8 @@ class Parameters {
int pickNbest;
int adjustKmerLength;
int resultDirection;
float weightThr;
std::string weightFile;

// indexdb
int checkCompatible;
@@ -828,6 +831,9 @@ class Parameters {
PARAMETER(PARAM_PICK_N_SIMILAR)
PARAMETER(PARAM_ADJUST_KMER_LEN)
PARAMETER(PARAM_RESULT_DIRECTION)
PARAMETER(PARAM_WEIGHT_FILE)
PARAMETER(PARAM_WEIGHT_THR)

// workflow
PARAMETER(PARAM_RUNNER)
PARAMETER(PARAM_REUSELATEST)
55 changes: 55 additions & 0 deletions src/commons/SequenceWeights.cpp
@@ -0,0 +1,55 @@
//
// Created by annika on 09.11.22.
//
#include "SequenceWeights.h"
#include "Util.h"
#include "Debug.h"

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <limits>

SequenceWeights::SequenceWeights(const char* dataFileName) {

//parse file and fill weightIndex
std::ifstream tsv(dataFileName);
if (tsv.fail()) {
Debug(Debug::ERROR) << "File " << dataFileName << " not found!\n";
EXIT(EXIT_FAILURE);
}

char keyData[255];
std::string line;
this->indexSize = 0;
unsigned int pos = 0;
while(std::getline(tsv, line)) {
this->indexSize++;
}

this->weightIndex = new WeightIndexEntry[this->indexSize];

tsv.clear();
tsv.seekg(0);

while(std::getline(tsv, line)) {
char *current = (char *) line.c_str();
Util::parseKey(current, keyData);
const std::string key(keyData);
unsigned int keyId = strtoull(key.c_str(), NULL, 10);

char *restStart = current + key.length();
restStart = restStart + Util::skipWhitespace(restStart);
float weight = static_cast<float>(strtod(restStart, NULL));
this->weightIndex[pos].id = keyId;
this->weightIndex[pos].weight = weight;
pos++;
}
}

float SequenceWeights::getWeightById(unsigned int id) {

WeightIndexEntry val;
val.id = id;
size_t pos = std::upper_bound(weightIndex, weightIndex + indexSize, val, WeightIndexEntry::compareByIdOnly) - weightIndex;
return (pos < indexSize && weightIndex[pos].id == id ) ? weightIndex[pos].weight : std::numeric_limits<float>::min();
}
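
The parser reads one "key whitespace weight" pair per line, first counting lines to size the array and then filling it in file order; because getWeightById binary-searches that array, the file is implicitly expected to be sorted by key in ascending order. A minimal standalone sketch of the same format and lookup (not the MMseqs2 classes; the values are made up, and a -1 sentinel stands in for the std::numeric_limits<float>::min() the original returns for missing keys):

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // entries as they would be parsed from a weight file such as
    //   3   0.25
    //   7   0.90
    //   11  0.60
    // (keys ascending, one "key whitespace weight" pair per line)
    std::vector<std::pair<unsigned int, float> > index = {{3, 0.25f}, {7, 0.90f}, {11, 0.60f}};

    // lookup analogous to getWeightById: binary search by key, sentinel if absent
    auto lookup = [&](unsigned int key) {
        auto it = std::lower_bound(index.begin(), index.end(), key,
                                   [](const std::pair<unsigned int, float> &e, unsigned int k) {
                                       return e.first < k;
                                   });
        return (it != index.end() && it->first == key) ? it->second : -1.0f; // -1 marks "no weight"
    };

    std::printf("key 7 -> %.2f\n", lookup(7));   // 0.90
    std::printf("key 5 -> %.2f\n", lookup(5));   // -1.00 (not in the file)
    return 0;
}
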

30 changes: 30 additions & 0 deletions src/commons/SequenceWeights.h
@@ -0,0 +1,30 @@
//
// Created by annika on 09.11.22.
//

#ifndef MMSEQS_SEQUENCEWEIGHTS_H
#define MMSEQS_SEQUENCEWEIGHTS_H

class SequenceWeights{
public:
struct WeightIndexEntry {
unsigned int id;
float weight;

static bool compareByIdOnly(const WeightIndexEntry &x, const WeightIndexEntry &y) {
return x.id <= y.id;
}
};

WeightIndexEntry *weightIndex;
unsigned int indexSize;

SequenceWeights(const char* dataFileName);

~SequenceWeights();

float getWeightById(unsigned int id);
};


#endif //MMSEQS_SEQUENCEWEIGHTS_H
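
One subtlety worth noting: compareByIdOnly returns x.id <= y.id, so the std::upper_bound call in getWeightById lands on the first entry whose id is greater than or equal to the query, i.e. it behaves like lower_bound with a strict less-than, which is what makes the subsequent equality check work. The snippet below illustrates that equivalence on a small array (purely illustrative; strictly speaking the standard requires a strict weak ordering here, which is one reason the lower_bound form is the more conventional spelling):

#include <algorithm>
#include <cassert>

struct Entry { unsigned int id; float weight; };

static bool lessOrEqualById(const Entry &x, const Entry &y) { return x.id <= y.id; }
static bool lessById(const Entry &x, const Entry &y) { return x.id < y.id; }

int main() {
    Entry entries[] = {{3, 0.25f}, {7, 0.90f}, {11, 0.60f}};
    Entry query; query.id = 7; query.weight = 0.0f;

    // upper_bound with "<=" lands on the first entry with id >= 7 ...
    const Entry *a = std::upper_bound(entries, entries + 3, query, lessOrEqualById);
    // ... exactly where lower_bound with "<" lands
    const Entry *b = std::lower_bound(entries, entries + 3, query, lessById);
    assert(a == b && a->id == 7);
    return 0;
}
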