#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace std;

10- void randomTest (int numItems, int dim, int numQueries, int K) {
11- std::default_random_engine generator;
12- std::uniform_real_distribution<double > distribution (0.0 , 1.0 );
13- std::vector<Item> randomItems;
14+ void readInputFromFile (const string& filename, int & dim, int & numItems, int & numQueries, vector<Item>& randomItems, vector<Item>& queries) {
15+ ifstream infile (filename);
16+ if (!infile.is_open ()) {
17+ cerr << " Unable to open file " << filename << endl;
18+ exit (EXIT_FAILURE);
19+ }
20+
21+ string firstLine;
22+ getline (infile, firstLine);
23+
24+ istringstream iss (firstLine);
25+ if (!(iss >> dim >> numItems >> numQueries) || iss.rdbuf ()->in_avail () > 0 ) {
26+ cerr << " Invalid format in the input file " << filename << endl;
27+ exit (EXIT_FAILURE);
28+ }
29+
1430 randomItems.reserve (numItems);
31+ queries.reserve (numQueries);
1532
16- for (int i = 0 ; i < numItems; i++) {
17- std::vector<double > temp (dim);
18- for (int d = 0 ; d < dim; d++) {
19- temp[d] = distribution (generator);
33+ // Read base vectors
34+ for (int i = 0 ; i < numItems; ++i) {
35+ vector<double > temp (dim);
36+ for (int j = 0 ; j < dim; ++j) {
37+ if (!(infile >> temp[j])) {
38+ cerr << " Invalid format in the input file " << filename << " at line " << (i + 2 ) << endl;
39+ exit (EXIT_FAILURE);
40+ }
2041 }
2142 randomItems.emplace_back (temp);
2243 }
2344
24- std::shuffle (randomItems.begin (), randomItems.end (), generator);
25- HNSWGraph myHNSWGraph (10 , 30 , 30 , 10 , 2 );
26- for (int i = 0 ; i < numItems; i++) {
27- if (i % 10000 == 0 ) cout << i << endl;
28- myHNSWGraph.Insert (randomItems[i]);
29- }
30-
31- double total_brute_force_time = 0.0 ;
32- double total_hnsw_time = 0.0 ;
33-
34- cout << " START QUERY" << endl;
35- int numHits = 0 ;
36- for (int i = 0 ; i < numQueries; i++) {
37- vector<double > temp (dim);
38- for (int d = 0 ; d < dim; d++) {
39- temp[d] = distribution (generator);
40- }
41- Item query (temp);
42-
43- // Brute force
44- clock_t begin_time = clock ();
45- vector<pair<double , int >> distPairs;
46- for (int j = 0 ; j < numItems; j++) {
47- if (j == i) continue ;
48- distPairs.emplace_back (query.dist (randomItems[j]), j);
49- }
50- sort (distPairs.begin (), distPairs.end ());
51- total_brute_force_time += double ( clock () - begin_time ) / CLOCKS_PER_SEC;
52-
53- begin_time = clock ();
54- vector<int > knns = myHNSWGraph.KNNSearch (query, K);
55- // cout << "Printing vectors";
56- // std::cout << "Contents of knns vector:" << std::endl;
57- // for (size_t i = 0; i < knns.size(); ++i) {
58- // std::cout << "knns[" << i << "] = " << knns[i] << std::endl;
59- // }
60- // cout << "\nPrinted Vectors";
61- total_hnsw_time += double ( clock () - begin_time ) / CLOCKS_PER_SEC;
62-
63- if (knns[0 ] == distPairs[0 ].second ) numHits++;
64- }
65- cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
45+ // Read queries
46+ for (int i = 0 ; i < numQueries; ++i) {
47+ vector<double > temp (dim);
48+ for (int j = 0 ; j < dim; ++j) {
49+ if (!(infile >> temp[j])) {
50+ cerr << " Invalid format in the input file " << filename << " at line " << (numItems + i + 2 ) << endl;
51+ exit (EXIT_FAILURE);
52+ }
53+ }
54+ queries.emplace_back (temp);
55+ }
56+
57+ infile.close ();
6658}
68- int main () {
69- randomTest (10000 , 4 , 100 , 5 );
60+
61+ int main (int argc, char * argv[]) {
62+
63+ if (argc != 2 ) {
64+ cerr << " Usage: " << argv[0 ] << " <filename>" << endl;
65+ return EXIT_FAILURE;
66+ }
67+
68+ string filename = argv[1 ];
69+ int K = 5 ;
70+
71+ ifstream infile (filename);
72+ if (!infile.is_open ()) {
73+ cerr << " Unable to open file " << filename << endl;
74+ exit (EXIT_FAILURE);
75+ }
76+
77+ int numItems = 0 , dim = 0 , numQueries = 0 ;
78+
79+ std::vector<Item> randomItems, queries;
80+
81+ readInputFromFile (filename, dim, numItems, numQueries, randomItems, queries);
82+
83+ HNSWGraph myHNSWGraph (10 , 30 , 30 , 10 , 2 );
84+
85+ for (int i = 0 ; i < numItems; ++i) {
86+ myHNSWGraph.Insert (randomItems[i]);
87+ }
88+
89+ double total_brute_force_time = 0.0 ;
90+ double total_hnsw_time = 0.0 ;
91+
92+ int numHits = 0 ;
93+
94+ for (int i = 0 ; i < numQueries; ++i) {
95+ Item query = queries[i];
96+
97+ // Brute force
98+ clock_t begin_time = clock ();
99+ vector<pair<double , int >> distPairs;
100+ for (int j = 0 ; j < numItems; ++j) {
101+ if (j == i) continue ;
102+ distPairs.emplace_back (query.dist (randomItems[j]), j);
103+ }
104+ sort (distPairs.begin (), distPairs.end ());
105+ total_brute_force_time += double (clock () - begin_time) / CLOCKS_PER_SEC;
106+
107+ begin_time = clock ();
108+ vector<int > knns = myHNSWGraph.KNNSearch (query, K);
109+ total_hnsw_time += double (clock () - begin_time) / CLOCKS_PER_SEC;
110+
111+ if (knns[0 ] == distPairs[0 ].second ) numHits++;
112+ }
113+ cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
114+
70115 return 0 ;
71116}