1111#include < sstream>
1212using namespace std ;
1313
14- void readInputFromFile (const string& filename, int & dim , int & numItems , int & numQueries , vector<Item>& randomItems , vector<Item>& queries) {
14+ void readInputFromFile (const string& filename, int & D , int & N , int & M , vector<Item>& base , vector<Item>& queries) {
1515 ifstream infile (filename);
1616 if (!infile.is_open ()) {
1717 cerr << " Unable to open file " << filename << endl;
@@ -22,32 +22,30 @@ void readInputFromFile(const string& filename, int& dim, int& numItems, int& num
2222 getline (infile, firstLine);
2323
2424 istringstream iss (firstLine);
25- if (!(iss >> dim >> numItems >> numQueries ) || iss.rdbuf ()->in_avail () > 0 ) {
25+ if (!(iss >> D >> N >> M ) || iss.rdbuf ()->in_avail () > 0 ) {
2626 cerr << " Invalid format in the input file " << filename << endl;
2727 exit (EXIT_FAILURE);
2828 }
2929
30- randomItems .reserve (numItems );
31- queries.reserve (numQueries );
30+ base .reserve (N );
31+ queries.reserve (M );
3232
33- // Read base vectors
34- for (int i = 0 ; i < numItems; ++i) {
35- vector<double > temp (dim);
36- for (int j = 0 ; j < dim; ++j) {
33+ for (int i = 0 ; i < N; ++i) {
34+ vector<double > temp (D);
35+ for (int j = 0 ; j < D; ++j) {
3736 if (!(infile >> temp[j])) {
3837 cerr << " Invalid format in the input file " << filename << " at line " << (i + 2 ) << endl;
3938 exit (EXIT_FAILURE);
4039 }
4140 }
42- randomItems .emplace_back (temp);
41+ base .emplace_back (temp);
4342 }
4443
45- // Read queries
46- for (int i = 0 ; i < numQueries; ++i) {
47- vector<double > temp (dim);
48- for (int j = 0 ; j < dim; ++j) {
44+ for (int i = 0 ; i < M; ++i) {
45+ vector<double > temp (D);
46+ for (int j = 0 ; j < D; ++j) {
4947 if (!(infile >> temp[j])) {
50- cerr << " Invalid format in the input file " << filename << " at line " << (numItems + i + 2 ) << endl;
48+ cerr << " Invalid format in the input file " << filename << " at line " << (N + i + 2 ) << endl;
5149 exit (EXIT_FAILURE);
5250 }
5351 }
@@ -60,12 +58,13 @@ void readInputFromFile(const string& filename, int& dim, int& numItems, int& num
6058
6159int main (int argc, char * argv[]) {
6260
63- if (argc != 2 ) {
64- cerr << " Usage: " << argv[0 ] << " <filename >" << endl;
61+ if (argc != 3 ) {
62+ cerr << " Usage: " << argv[0 ] << " <input_filename> <output_filename >" << endl;
6563 return EXIT_FAILURE;
6664 }
6765
6866 string filename = argv[1 ];
67+ string outputFilename = argv[2 ];
6968 int K = 5 ;
7069
7170 ifstream infile (filename);
@@ -74,43 +73,55 @@ int main(int argc, char* argv[]) {
7473 exit (EXIT_FAILURE);
7574 }
7675
77- int numItems = 0 , dim = 0 , numQueries = 0 ;
76+ int N = 0 , D = 0 , M = 0 ;
7877
79- std::vector<Item> randomItems , queries;
78+ std::vector<Item> base , queries;
8079
81- readInputFromFile (filename, dim, numItems, numQueries, randomItems , queries);
80+ readInputFromFile (filename, D, N, M, base , queries);
8281
8382 HNSWGraph myHNSWGraph (10 , 30 , 30 , 10 , 2 );
8483
85- for (int i = 0 ; i < numItems ; ++i) {
86- myHNSWGraph.Insert (randomItems [i]);
84+ for (int i = 0 ; i < N ; ++i) {
85+ myHNSWGraph.Insert (base [i]);
8786 }
8887
8988 double total_brute_force_time = 0.0 ;
9089 double total_hnsw_time = 0.0 ;
9190
9291 int numHits = 0 ;
92+ ofstream outfile (outputFilename);
9393
94- for (int i = 0 ; i < numQueries; ++i) {
95- Item query = queries[i];
94+ if (!outfile.is_open ()) {
95+ cerr << " Unable to open output file " << outputFilename << endl;
96+ exit (EXIT_FAILURE);
97+ }
9698
97- // Brute force
99+ for (int i = 0 ; i < M; ++i) {
100+ Item query = queries[i];
98101 clock_t begin_time = clock ();
99102 vector<pair<double , int >> distPairs;
100- for (int j = 0 ; j < numItems ; ++j) {
103+ for (int j = 0 ; j < N ; ++j) {
101104 if (j == i) continue ;
102- distPairs.emplace_back (query.dist (randomItems [j]), j);
105+ distPairs.emplace_back (query.dist (base [j]), j);
103106 }
104107 sort (distPairs.begin (), distPairs.end ());
105108 total_brute_force_time += double (clock () - begin_time) / CLOCKS_PER_SEC;
106109
107110 begin_time = clock ();
108111 vector<int > knns = myHNSWGraph.KNNSearch (query, K);
112+ for (size_t idx = 0 ; idx < knns.size (); ++idx) {
113+ outfile << knns[idx];
114+ if (idx != knns.size () - 1 ) {
115+ outfile << " " ;
116+ }
117+ }
118+ outfile << endl;
109119 total_hnsw_time += double (clock () - begin_time) / CLOCKS_PER_SEC;
110120
111121 if (knns[0 ] == distPairs[0 ].second ) numHits++;
112122 }
113- cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
123+ outfile.close ();
124+ cout << numHits << " " << total_brute_force_time / M << " " << total_hnsw_time / M << endl;
114125
115126 return 0 ;
116127}
0 commit comments