44#include < ctime>
55#include < iostream>
66#include < random>
7+ #include < fstream>
78#include < vector>
9+ #include < memory>
10+ #include < string>
11+ #include < sstream>
812using namespace std ;
913
10- void randomTest (int numItems, int dim, int numQueries, int K) {
11- std::default_random_engine generator;
12- std::uniform_real_distribution<double > distribution (0.0 , 1.0 );
13- std::vector<Item> randomItems;
14- randomItems.reserve (numItems);
14+ void readInputFromFile (const string& filename, int & D, int & N, int & M, vector<Item>& base, vector<Item>& queries) {
15+ ifstream infile (filename);
16+ if (!infile.is_open ()) {
17+ cerr << " Unable to open file " << filename << endl;
18+ exit (EXIT_FAILURE);
19+ }
20+
21+ string firstLine;
22+ getline (infile, firstLine);
23+
24+ istringstream iss (firstLine);
25+ if (!(iss >> D >> N >> M) || iss.rdbuf ()->in_avail () > 0 ) {
26+ cerr << " Invalid format in the input file " << filename << endl;
27+ exit (EXIT_FAILURE);
28+ }
29+
30+ base.reserve (N);
31+ queries.reserve (M);
32+
33+ for (int i = 0 ; i < N; ++i) {
34+ vector<double > temp (D);
35+ for (int j = 0 ; j < D; ++j) {
36+ if (!(infile >> temp[j])) {
37+ cerr << " Invalid format in the input file " << filename << " at line " << (i + 2 ) << endl;
38+ exit (EXIT_FAILURE);
39+ }
40+ }
41+ base.emplace_back (temp);
42+ }
1543
16- for (int i = 0 ; i < numItems; i++) {
17- std::vector<double > temp (dim);
18- for (int d = 0 ; d < dim; d++) {
19- temp[d] = distribution (generator);
44+ for (int i = 0 ; i < M; ++i) {
45+ vector<double > temp (D);
46+ for (int j = 0 ; j < D; ++j) {
47+ if (!(infile >> temp[j])) {
48+ cerr << " Invalid format in the input file " << filename << " at line " << (N + i + 2 ) << endl;
49+ exit (EXIT_FAILURE);
50+ }
2051 }
21- randomItems .emplace_back (temp);
52+ queries .emplace_back (temp);
2253 }
2354
24- std::shuffle (randomItems.begin (), randomItems.end (), generator);
25- HNSWGraph myHNSWGraph (10 , 30 , 30 , 10 , 2 );
26- for (int i = 0 ; i < numItems; i++) {
27- if (i % 10000 == 0 ) cout << i << endl;
28- myHNSWGraph.Insert (randomItems[i]);
29- }
30-
31- double total_brute_force_time = 0.0 ;
32- double total_hnsw_time = 0.0 ;
33-
34- cout << " START QUERY" << endl;
35- int numHits = 0 ;
36- for (int i = 0 ; i < numQueries; i++) {
37- vector<double > temp (dim);
38- for (int d = 0 ; d < dim; d++) {
39- temp[d] = distribution (generator);
40- }
41- Item query (temp);
42-
43- // Brute force
44- clock_t begin_time = clock ();
45- vector<pair<double , int >> distPairs;
46- for (int j = 0 ; j < numItems; j++) {
47- if (j == i) continue ;
48- distPairs.emplace_back (query.dist (randomItems[j]), j);
49- }
50- sort (distPairs.begin (), distPairs.end ());
51- total_brute_force_time += double ( clock () - begin_time ) / CLOCKS_PER_SEC;
52-
53- begin_time = clock ();
54- vector<int > knns = myHNSWGraph.KNNSearch (query, K);
55- // cout << "Printing vectors";
56- // std::cout << "Contents of knns vector:" << std::endl;
57- // for (size_t i = 0; i < knns.size(); ++i) {
58- // std::cout << "knns[" << i << "] = " << knns[i] << std::endl;
59- // }
60- // cout << "\nPrinted Vectors";
61- total_hnsw_time += double ( clock () - begin_time ) / CLOCKS_PER_SEC;
62-
63- if (knns[0 ] == distPairs[0 ].second ) numHits++;
64- }
65- cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
55+ infile.close ();
6656}
6757
68- int main () {
69- randomTest (10000 , 4 , 100 , 5 );
58+
59+ int main (int argc, char * argv[]) {
60+
61+ if (argc != 3 ) {
62+ cerr << " Usage: " << argv[0 ] << " <input_filename> <output_filename>" << endl;
63+ return EXIT_FAILURE;
64+ }
65+
66+ string filename = argv[1 ];
67+ string outputFilename = argv[2 ];
68+ int K = 5 ;
69+
70+ ifstream infile (filename);
71+ if (!infile.is_open ()) {
72+ cerr << " Unable to open file " << filename << endl;
73+ exit (EXIT_FAILURE);
74+ }
75+
76+ int N = 0 , D = 0 , M = 0 ;
77+
78+ std::vector<Item> base, queries;
79+
80+ readInputFromFile (filename, D, N, M, base, queries);
81+
82+ HNSWGraph myHNSWGraph (10 , 30 , 30 , 10 , 2 );
83+
84+ for (int i = 0 ; i < N; ++i) {
85+ myHNSWGraph.Insert (base[i]);
86+ }
87+
88+ double total_brute_force_time = 0.0 ;
89+ double total_hnsw_time = 0.0 ;
90+
91+ int numHits = 0 ;
92+ ofstream outfile (outputFilename);
93+
94+ if (!outfile.is_open ()) {
95+ cerr << " Unable to open output file " << outputFilename << endl;
96+ exit (EXIT_FAILURE);
97+ }
98+
99+ for (int i = 0 ; i < M; ++i) {
100+ Item query = queries[i];
101+ clock_t begin_time = clock ();
102+ vector<pair<double , int >> distPairs;
103+ for (int j = 0 ; j < N; ++j) {
104+ if (j == i) continue ;
105+ distPairs.emplace_back (query.dist (base[j]), j);
106+ }
107+ sort (distPairs.begin (), distPairs.end ());
108+ total_brute_force_time += double (clock () - begin_time) / CLOCKS_PER_SEC;
109+
110+ begin_time = clock ();
111+ vector<int > knns = myHNSWGraph.KNNSearch (query, K);
112+ for (size_t idx = 0 ; idx < knns.size (); ++idx) {
113+ outfile << knns[idx];
114+ if (idx != knns.size () - 1 ) {
115+ outfile << " " ;
116+ }
117+ }
118+ outfile << endl;
119+ total_hnsw_time += double (clock () - begin_time) / CLOCKS_PER_SEC;
120+
121+ if (knns[0 ] == distPairs[0 ].second ) numHits++;
122+ }
123+ outfile.close ();
124+ cout << numHits << " " << total_brute_force_time / M << " " << total_hnsw_time / M << endl;
125+
70126 return 0 ;
71127}
0 commit comments