Skip to content

Commit

Permalink
Merge pull request #343 from steineggerlab/foldseek_multimer_bottleneck
Browse files Browse the repository at this point in the history
Foldseek multimer bottleneck solved
  • Loading branch information
Woosub-Kim authored Sep 6, 2024
2 parents 868bfb1 + 3bcdaba commit be9fc33
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 75 deletions.
11 changes: 3 additions & 8 deletions src/strucclustutils/MultimerUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,14 @@ const double TOO_SMALL_CV = 0.1;
const double FILTERED_OUT = 0.0;
const unsigned int INITIALIZED_LABEL = 0;
const unsigned int MIN_PTS = 2;
const float DEFAULT_EPS = 0.1;
const float LEARNING_RATE = 0.1;
const float TM_SCORE_MARGIN = 0.7;
const float DEF_TM_SCORE = -1.0;
const int UNINITIALIZED = 0;
const unsigned int MULTIPLE_CHAINED_COMPLEX = 2;
const unsigned int SIZE_OF_SUPERPOSITION_VECTOR = 12;
typedef std::pair<std::string, std::string> compNameChainName_t;
typedef std::map<unsigned int, unsigned int> chainKeyToComplexId_t;
typedef std::map<unsigned int, std::vector<unsigned int>> complexIdToChainKeys_t;
typedef std::vector<unsigned int> cluster_t;
typedef std::map<std::pair<unsigned int, unsigned int>, float> distMap_t;
typedef std::string resultToWrite_t;
typedef std::string chainName_t;
typedef std::pair<unsigned int, resultToWrite_t> resultToWriteWithKey_t;
Expand Down Expand Up @@ -96,13 +92,12 @@ struct ChainToChainAln {
unsigned int label;
float tmScore;

float getDistance(const ChainToChainAln &o) {
float getDistance(const ChainToChainAln &o) const {
float dist = 0;
for (size_t i=0; i<SIZE_OF_SUPERPOSITION_VECTOR; i++) {
dist += std::pow(superposition[i] - o.superposition[i], 2);
}
dist = std::sqrt(dist);
return dist;
return std::sqrt(dist);
}

void free() {
Expand Down Expand Up @@ -229,4 +224,4 @@ static ComplexDataHandler parseScoreComplexResult(const char *data, Matcher::res
return {assId, qTmScore, tTmScore, uString, tString, true};
}

#endif //FOLDSEEK_MULTIMERUTIL_H
#endif //FOLDSEEK_MULTIMERUTIL_H
145 changes: 78 additions & 67 deletions src/strucclustutils/scoremultimer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "Coordinate16.h"
#include "MultimerUtil.h"
#include "set"

#include "unordered_set"
#ifdef OPENMP
#include <omp.h>
#endif
Expand Down Expand Up @@ -193,7 +193,7 @@ class DBSCANCluster {
bool getAlnClusters() {
// rbh filter
filterAlnsByRBH();
fillDistMap();
fillDistMatrix();
// To skip DBSCAN clustering when alignments are few enough.
if (searchResult.alnVec.size() <= maximumClusterSize)
return checkClusteringNecessity();
Expand All @@ -215,95 +215,106 @@ class DBSCANCluster {
unsigned int minimumClusterSize;
std::vector<unsigned int> neighbors;
std::vector<unsigned int> neighborsOfCurrNeighbor;
std::unordered_set<unsigned int> foundNeighbors;
std::vector<NeighborsWithDist> neighborsWithDist;
std::set<unsigned int> qFoundChainKeys;
std::set<unsigned int> dbFoundChainKeys;
distMap_t distMap;
std::unordered_set<unsigned int> qFoundChainKeys;
std::unordered_set<unsigned int> dbFoundChainKeys;
std::vector<float> distMatrix;
std::vector<cluster_t> currClusters;
std::set<cluster_t> &finalClusters;
std::map<unsigned int, float> qBestTmScore;
std::map<unsigned int, float> dbBestTmScore;

bool runDBSCAN() {
initializeAlnLabels();
if (eps >= maxDist)
return finishDBSCAN();

for (size_t centerAlnIdx=0; centerAlnIdx < searchResult.alnVec.size(); centerAlnIdx++) {
ChainToChainAln &centerAln = searchResult.alnVec[centerAlnIdx];
if (centerAln.label != 0)
continue;

getNeighbors(centerAlnIdx, neighbors);
if (neighbors.size() < MIN_PTS)
continue;
unsigned int neighborIdx;
unsigned int neighborAlnIdx;
while (eps < maxDist) {
initializeAlnLabels();
for (size_t centerAlnIdx = 0; centerAlnIdx < searchResult.alnVec.size(); centerAlnIdx++) {
ChainToChainAln &centerAln = searchResult.alnVec[centerAlnIdx];
if (centerAln.label != 0)
continue;

centerAln.label = ++cLabel;
size_t neighborIdx = 0;
while (neighborIdx < neighbors.size()) {
unsigned int neighborAlnIdx = neighbors[neighborIdx++];
if (centerAlnIdx == neighborAlnIdx)
getNeighbors(centerAlnIdx, neighbors);
if (neighbors.size() < MIN_PTS)
continue;

ChainToChainAln &neighborAln = searchResult.alnVec[neighborAlnIdx];
neighborAln.label = cLabel;
getNeighbors(neighborAlnIdx, neighborsOfCurrNeighbor);
if (neighborsOfCurrNeighbor.size() < MIN_PTS)
centerAln.label = ++cLabel;
foundNeighbors.clear();
foundNeighbors.insert(neighbors.begin(), neighbors.end());
neighborIdx = 0;
while (neighborIdx < neighbors.size()) {
neighborAlnIdx = neighbors[neighborIdx++];
if (centerAlnIdx == neighborAlnIdx)
continue;

ChainToChainAln &neighborAln = searchResult.alnVec[neighborAlnIdx];
neighborAln.label = cLabel;
getNeighbors(neighborAlnIdx, neighborsOfCurrNeighbor);
if (neighborsOfCurrNeighbor.size() < MIN_PTS)
continue;

for (auto neighbor : neighborsOfCurrNeighbor) {
if (foundNeighbors.insert(neighbor).second)
neighbors.emplace_back(neighbor);
}
}
if (neighbors.size() > maximumClusterSize || checkChainRedundancy())
getNearestNeighbors(centerAlnIdx);

// too small cluster
if (neighbors.size() < currMaxClusterSize)
continue;

for (auto neighbor : neighborsOfCurrNeighbor) {
if (std::find(neighbors.begin(), neighbors.end(), neighbor) == neighbors.end())
neighbors.emplace_back(neighbor);
// new Biggest cluster
if (neighbors.size() > currMaxClusterSize) {
currMaxClusterSize = neighbors.size();
currClusters.clear();
}
SORT_SERIAL(neighbors.begin(), neighbors.end());
currClusters.emplace_back(neighbors);
}
if (neighbors.size() > maximumClusterSize || checkChainRedundancy())
getNearestNeighbors(centerAlnIdx);

// too small cluster
if (neighbors.size() < currMaxClusterSize)
continue;
if (!finalClusters.empty() && currClusters.empty())
return finishDBSCAN();

// new Biggest cluster
if (neighbors.size() >currMaxClusterSize) {
currMaxClusterSize = neighbors.size();
currClusters.clear();
if (currMaxClusterSize < prevMaxClusterSize)
return finishDBSCAN();

if (currMaxClusterSize > prevMaxClusterSize) {
finalClusters.clear();
prevMaxClusterSize = currMaxClusterSize;
}
SORT_SERIAL(neighbors.begin(), neighbors.end());
currClusters.emplace_back(neighbors);
}

if (!finalClusters.empty() && currClusters.empty())
return finishDBSCAN();
if (currMaxClusterSize >= minimumClusterSize)
finalClusters.insert(currClusters.begin(), currClusters.end());

if (currMaxClusterSize < prevMaxClusterSize)
return finishDBSCAN();
if (currMaxClusterSize == maximumClusterSize && finalClusters.size() == maximumClusterNum)
return finishDBSCAN();

if (currMaxClusterSize > prevMaxClusterSize) {
finalClusters.clear();
prevMaxClusterSize = currMaxClusterSize;
eps += learningRate;
}
return finishDBSCAN();
}

if (currMaxClusterSize >= minimumClusterSize)
finalClusters.insert(currClusters.begin(), currClusters.end());

if (currMaxClusterSize==maximumClusterSize && finalClusters.size() == maximumClusterNum)
return finishDBSCAN();

eps += learningRate;
return runDBSCAN();
size_t getDistMatrixIndex(size_t i, size_t j) const {
if (i > j) std::swap(i, j); // Ensure i <= j for symmetry
size_t n = searchResult.alnVec.size();
return (2 * n *i - i - i * i) / 2 + j - i - 1;
}

void fillDistMap() {
void fillDistMatrix() {
size_t size = searchResult.alnVec.size();
float dist;
distMap.clear();
for (size_t i=0; i < searchResult.alnVec.size(); i++) {
ChainToChainAln &prevAln = searchResult.alnVec[i];
for (size_t j = i+1; j < searchResult.alnVec.size(); j++) {
ChainToChainAln &currAln = searchResult.alnVec[j];
distMatrix.resize(size * (size - 1) / 2, 0.0f);
for (size_t i = 0; i < searchResult.alnVec.size(); i++) {
const ChainToChainAln &prevAln = searchResult.alnVec[i];
for (size_t j = i + 1; j < searchResult.alnVec.size(); j++) {
const ChainToChainAln &currAln = searchResult.alnVec[j];
dist = prevAln.getDistance(currAln);
maxDist = std::max(maxDist, dist);
minDist = std::min(minDist, dist);
distMap.insert({{i,j}, dist});
distMatrix[getDistMatrixIndex(i, j)] = dist;
}
}
eps = minDist;
Expand All @@ -317,7 +328,7 @@ class DBSCANCluster {
if (neighborIdx == centerIdx)
continue;

if (distMap[{std::min(centerIdx, neighborIdx), std::max(centerIdx, neighborIdx)}] >= eps)
if (distMatrix[getDistMatrixIndex(centerIdx, neighborIdx)] >= eps)
continue;

neighborVec.emplace_back(neighborIdx);
Expand Down Expand Up @@ -374,7 +385,7 @@ class DBSCANCluster {
dbBestTmScore.clear();
qFoundChainKeys.clear();
dbFoundChainKeys.clear();
distMap.clear();
distMatrix.clear();
return !finalClusters.empty();
}

Expand Down Expand Up @@ -425,7 +436,7 @@ class DBSCANCluster {
for (auto neighborIdx: neighbors) {
if (neighborIdx == centerIdx)
continue;
neighborsWithDist.emplace_back(neighborIdx, distMap[{std::min(centerIdx, neighborIdx), std::max(centerIdx, neighborIdx)}]);
neighborsWithDist.emplace_back(neighborIdx, distMatrix[getDistMatrixIndex(centerIdx, neighborIdx)]);
}
SORT_SERIAL(neighborsWithDist.begin(), neighborsWithDist.end(), compareNeighborWithDist);
neighbors.clear();
Expand Down Expand Up @@ -533,7 +544,7 @@ class ComplexScorer {
tmAligner = new TMaligner(maxResLen, false, true, false);
}
finalClusters.clear();
DBSCANCluster dbscanCluster = DBSCANCluster(searchResult, finalClusters, minAssignedChainsRatio);
DBSCANCluster dbscanCluster(searchResult, finalClusters, minAssignedChainsRatio);
if (!dbscanCluster.getAlnClusters()) {
finalClusters.clear();
return;
Expand Down

0 comments on commit be9fc33

Please sign in to comment.