1010#include < omp.h>
1111using namespace std ;
1212
13- vector<int > HNSWGraph::searchLayer (Item& q, int ep , int ef, int lc) {
13+ vector<int > HNSWGraph::searchLayer (Item& q, int entry_point , int ef, int lc) {
1414 set<pair<double , int >> candidates;
1515 set<pair<double , int >> nearestNeighbors;
1616 unordered_set<int > isVisited;
1717
18- double td = q.dist (items[ep ]);
19- candidates.insert (make_pair (td, ep ));
20- nearestNeighbors.insert (make_pair (td, ep ));
21- isVisited.insert (ep );
18+ double td = q.dist (items[entry_point ]);
19+ candidates.insert (make_pair (td, entry_point ));
20+ nearestNeighbors.insert (make_pair (td, entry_point ));
21+ isVisited.insert (entry_point );
2222 while (!candidates.empty ()) {
23- auto ci = candidates.begin (); candidates.erase (candidates.begin ());
23+ auto ci = candidates.begin ();
24+ candidates.erase (candidates.begin ());
2425 int nid = ci->second ;
2526 auto fi = nearestNeighbors.end (); fi--;
26- if (ci->first > fi->first ) break ;
27+ if (ci->first > fi->first ) {
28+ break ;
29+ }
2730 for (int ed: layerEdgeLists[lc][nid]) {
28- if (isVisited.find (ed) != isVisited.end ()) continue ;
29- fi = nearestNeighbors.end (); fi--;
31+ if (isVisited.find (ed) != isVisited.end ()) {
32+ continue ;
33+ }
34+ fi = nearestNeighbors.end ();
35+ fi--;
3036 isVisited.insert (ed);
3137 td = q.dist (items[ed]);
3238 if ((td < fi->first ) || nearestNeighbors.size () < ef) {
3339 candidates.insert (make_pair (td, ed));
3440 nearestNeighbors.insert (make_pair (td, ed));
35- if (nearestNeighbors.size () > ef) nearestNeighbors.erase (fi);
41+ if (nearestNeighbors.size () > ef) {
42+ nearestNeighbors.erase (fi);
43+ }
3644 }
3745 }
3846 }
@@ -43,50 +51,65 @@ vector<int> HNSWGraph::searchLayer(Item& q, int ep, int ef, int lc) {
4351
4452vector<int > HNSWGraph::KNNSearch (Item& q, int K) {
4553 int maxLyer = layerEdgeLists.size () - 1 ;
46- int ep = enterNode;
47- for (int l = maxLyer; l >= 1 ; l--) ep = searchLayer (q, ep, 1 , l)[0 ];
48- return searchLayer (q, ep, K, 0 );
54+ int entry_point = enterNode;
55+ for (int l = maxLyer; l >= 1 ; l--) {
56+ entry_point = searchLayer (q, entry_point, 1 , l)[0 ];
57+ }
58+ return searchLayer (q, entry_point, K, 0 );
4959}
5060
5161void HNSWGraph::addEdge (int st, int ed, int lc) {
52- if (st == ed) return ;
62+ if (st == ed) {
63+ return ;
64+ }
5365 layerEdgeLists[lc][st].push_back (ed);
5466 layerEdgeLists[lc][ed].push_back (st);
5567}
5668
5769void HNSWGraph::Insert (Item& q) {
5870 int nid = items.size ();
5971 itemNum++; items.push_back (q);
60- // sample layer
6172 int maxLyer = layerEdgeLists.size () - 1 ;
6273 int l = 0 ;
6374 uniform_real_distribution<double > distribution (0.0 ,1.0 );
6475 while (l < ml && (1.0 / ml <= distribution (generator))) {
6576 l++;
66- if (layerEdgeLists.size () <= l) layerEdgeLists.push_back (unordered_map<int , vector<int >>());
77+ if (layerEdgeLists.size () <= l) {
78+ layerEdgeLists.push_back (unordered_map<int , vector<int >>());
79+ }
6780 }
6881 if (nid == 0 ) {
6982 enterNode = nid;
7083 return ;
7184 }
72- // search up layer entrance
73- int ep = enterNode;
74- for (int i = maxLyer; i > l; i--) ep = searchLayer (q, ep, 1 , i)[0 ];
85+ int entry_point = enterNode;
86+ for (int i = maxLyer; i > l; i--) {
87+ entry_point = searchLayer (q, entry_point, 1 , i)[0 ];
88+ }
89+ #pragma omp parallel for
7590 for (int i = min (l, maxLyer); i >= 0 ; i--) {
7691 int MM = l == 0 ? MMax0 : MMax;
77- vector<int > neighbors = searchLayer (q, ep , efConstruction, i);
92+ vector<int > neighbors = searchLayer (q, entry_point , efConstruction, i);
7893 vector<int > selectedNeighbors = vector<int >(neighbors.begin (), neighbors.begin ()+min (int (neighbors.size ()), M));
79- for (int n: selectedNeighbors) addEdge (n, nid, i);
94+ for (int n: selectedNeighbors) {
95+ addEdge (n, nid, i);
96+ }
8097 for (int n: selectedNeighbors) {
8198 if (layerEdgeLists[i][n].size () > MM) {
8299 vector<pair<double , int >> distPairs;
83- for (int nn: layerEdgeLists[i][n]) distPairs.emplace_back (items[n].dist (items[nn]), nn);
100+ for (int nn: layerEdgeLists[i][n]) {
101+ distPairs.emplace_back (items[n].dist (items[nn]), nn);
102+ }
84103 sort (distPairs.begin (), distPairs.end ());
85104 layerEdgeLists[i][n].clear ();
86- for (int d = 0 ; d < min (int (distPairs.size ()), MM); d++) layerEdgeLists[i][n].push_back (distPairs[d].second );
105+ for (int d = 0 ; d < min (int (distPairs.size ()), MM); d++) {
106+ layerEdgeLists[i][n].push_back (distPairs[d].second );
107+ }
87108 }
88109 }
89- ep = selectedNeighbors[0 ];
110+ entry_point = selectedNeighbors[0 ];
111+ }
112+ if (l == layerEdgeLists.size () - 1 ) {
113+ enterNode = nid;
90114 }
91- if (l == layerEdgeLists.size () - 1 ) enterNode = nid;
92115}
0 commit comments