Skip to content

Commit ecc7a5f

Browse files
committed
Illegal input handling
1 parent fb45eb0 commit ecc7a5f

File tree

4 files changed

+98
-53
lines changed

4 files changed

+98
-53
lines changed

main.cpp

Lines changed: 97 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,68 +4,113 @@
44
#include <ctime>
55
#include <iostream>
66
#include <random>
7+
#include <fstream>
78
#include <vector>
9+
#include <memory>
10+
#include <string>
11+
#include <sstream>
812
using namespace std;
913

10-
void randomTest(int numItems, int dim, int numQueries, int K) {
11-
std::default_random_engine generator;
12-
std::uniform_real_distribution<double> distribution(0.0, 1.0);
13-
std::vector<Item> randomItems;
14+
void readInputFromFile(const string& filename, int& dim, int& numItems, int& numQueries, vector<Item>& randomItems, vector<Item>& queries) {
15+
ifstream infile(filename);
16+
if (!infile.is_open()) {
17+
cerr << "Unable to open file " << filename << endl;
18+
exit(EXIT_FAILURE);
19+
}
20+
21+
string firstLine;
22+
getline(infile, firstLine);
23+
24+
istringstream iss(firstLine);
25+
if (!(iss >> dim >> numItems >> numQueries) || iss.rdbuf()->in_avail() > 0) {
26+
cerr << "Invalid format in the input file " << filename << endl;
27+
exit(EXIT_FAILURE);
28+
}
29+
1430
randomItems.reserve(numItems);
31+
queries.reserve(numQueries);
1532

16-
for (int i = 0; i < numItems; i++) {
17-
std::vector<double> temp(dim);
18-
for (int d = 0; d < dim; d++) {
19-
temp[d] = distribution(generator);
33+
// Read base vectors
34+
for (int i = 0; i < numItems; ++i) {
35+
vector<double> temp(dim);
36+
for (int j = 0; j < dim; ++j) {
37+
if (!(infile >> temp[j])) {
38+
cerr << "Invalid format in the input file " << filename << " at line " << (i + 2) << endl;
39+
exit(EXIT_FAILURE);
40+
}
2041
}
2142
randomItems.emplace_back(temp);
2243
}
2344

24-
std::shuffle(randomItems.begin(), randomItems.end(), generator);
25-
HNSWGraph myHNSWGraph(10, 30, 30, 10, 2);
26-
for (int i = 0; i < numItems; i++) {
27-
if (i % 10000 == 0) cout << i << endl;
28-
myHNSWGraph.Insert(randomItems[i]);
29-
}
30-
31-
double total_brute_force_time = 0.0;
32-
double total_hnsw_time = 0.0;
33-
34-
cout << "START QUERY" << endl;
35-
int numHits = 0;
36-
for (int i = 0; i < numQueries; i++) {
37-
vector<double> temp(dim);
38-
for (int d = 0; d < dim; d++) {
39-
temp[d] = distribution(generator);
40-
}
41-
Item query(temp);
42-
43-
// Brute force
44-
clock_t begin_time = clock();
45-
vector<pair<double, int>> distPairs;
46-
for (int j = 0; j < numItems; j++) {
47-
if (j == i) continue;
48-
distPairs.emplace_back(query.dist(randomItems[j]), j);
49-
}
50-
sort(distPairs.begin(), distPairs.end());
51-
total_brute_force_time += double( clock () - begin_time ) / CLOCKS_PER_SEC;
52-
53-
begin_time = clock();
54-
vector<int> knns = myHNSWGraph.KNNSearch(query, K);
55-
// cout << "Printing vectors";
56-
// std::cout << "Contents of knns vector:" << std::endl;
57-
// for (size_t i = 0; i < knns.size(); ++i) {
58-
// std::cout << "knns[" << i << "] = " << knns[i] << std::endl;
59-
// }
60-
// cout << "\nPrinted Vectors";
61-
total_hnsw_time += double( clock () - begin_time ) / CLOCKS_PER_SEC;
62-
63-
if (knns[0] == distPairs[0].second) numHits++;
64-
}
65-
cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
45+
// Read queries
46+
for (int i = 0; i < numQueries; ++i) {
47+
vector<double> temp(dim);
48+
for (int j = 0; j < dim; ++j) {
49+
if (!(infile >> temp[j])) {
50+
cerr << "Invalid format in the input file " << filename << " at line " << (numItems + i + 2) << endl;
51+
exit(EXIT_FAILURE);
52+
}
53+
}
54+
queries.emplace_back(temp);
55+
}
56+
57+
infile.close();
6658
}
6759

68-
int main() {
69-
randomTest(10000, 4, 100, 5);
60+
61+
int main(int argc, char* argv[]) {
62+
63+
if (argc != 2) {
64+
cerr << "Usage: " << argv[0] << " <filename>" << endl;
65+
return EXIT_FAILURE;
66+
}
67+
68+
string filename = argv[1];
69+
int K = 5;
70+
71+
ifstream infile(filename);
72+
if (!infile.is_open()) {
73+
cerr << "Unable to open file " << filename << endl;
74+
exit(EXIT_FAILURE);
75+
}
76+
77+
int numItems = 0, dim = 0, numQueries = 0;
78+
79+
std::vector<Item> randomItems, queries;
80+
81+
readInputFromFile(filename, dim, numItems, numQueries, randomItems, queries);
82+
83+
HNSWGraph myHNSWGraph(10, 30, 30, 10, 2);
84+
85+
for (int i = 0; i < numItems; ++i) {
86+
myHNSWGraph.Insert(randomItems[i]);
87+
}
88+
89+
double total_brute_force_time = 0.0;
90+
double total_hnsw_time = 0.0;
91+
92+
int numHits = 0;
93+
94+
for (int i = 0; i < numQueries; ++i) {
95+
Item query = queries[i];
96+
97+
// Brute force
98+
clock_t begin_time = clock();
99+
vector<pair<double, int>> distPairs;
100+
for (int j = 0; j < numItems; ++j) {
101+
if (j == i) continue;
102+
distPairs.emplace_back(query.dist(randomItems[j]), j);
103+
}
104+
sort(distPairs.begin(), distPairs.end());
105+
total_brute_force_time += double(clock() - begin_time) / CLOCKS_PER_SEC;
106+
107+
begin_time = clock();
108+
vector<int> knns = myHNSWGraph.KNNSearch(query, K);
109+
total_hnsw_time += double(clock() - begin_time) / CLOCKS_PER_SEC;
110+
111+
if (knns[0] == distPairs[0].second) numHits++;
112+
}
113+
cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
114+
70115
return 0;
71116
}

main.o

-35.3 KB
Binary file not shown.

my_program.exe

-8.2 KB
Binary file not shown.

run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
./my_program
1+
# Run the benchmark on the input file given as the first argument.
# The argument is quoted so paths containing spaces or glob characters reach the
# program as a single argument; ${1+...} passes nothing when no argument was
# supplied, so the program still prints its usage message in that case.
./my_program ${1+"$1"}

0 commit comments

Comments
 (0)