Skip to content

Commit bb36184

Browse files
authored
Merge pull request #2 from CSCI-739/initial-draft
Initial draft merge 2
2 parents 2013f80 + cb3f618 commit bb36184

File tree

10 files changed

+115
-82
lines changed

10 files changed

+115
-82
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,2 +1,4 @@
11
*.exe
2-
*.in
2+
*.in
3+
*.out
4+
*.o

hnsw.o

-662 KB
Binary file not shown.
File renamed without changes.
File renamed without changes.

main.cpp

Lines changed: 110 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -4,68 +4,124 @@
44
#include <ctime>
55
#include <iostream>
66
#include <random>
7+
#include <fstream>
78
#include <vector>
9+
#include <memory>
10+
#include <string>
11+
#include <sstream>
812
using namespace std;
913

10-
void randomTest(int numItems, int dim, int numQueries, int K) {
11-
std::default_random_engine generator;
12-
std::uniform_real_distribution<double> distribution(0.0, 1.0);
13-
std::vector<Item> randomItems;
14-
randomItems.reserve(numItems);
14+
void readInputFromFile(const string& filename, int& D, int& N, int& M, vector<Item>& base, vector<Item>& queries) {
15+
ifstream infile(filename);
16+
if (!infile.is_open()) {
17+
cerr << "Unable to open file " << filename << endl;
18+
exit(EXIT_FAILURE);
19+
}
20+
21+
string firstLine;
22+
getline(infile, firstLine);
23+
24+
istringstream iss(firstLine);
25+
if (!(iss >> D >> N >> M) || iss.rdbuf()->in_avail() > 0) {
26+
cerr << "Invalid format in the input file " << filename << endl;
27+
exit(EXIT_FAILURE);
28+
}
29+
30+
base.reserve(N);
31+
queries.reserve(M);
32+
33+
for (int i = 0; i < N; ++i) {
34+
vector<double> temp(D);
35+
for (int j = 0; j < D; ++j) {
36+
if (!(infile >> temp[j])) {
37+
cerr << "Invalid format in the input file " << filename << " at line " << (i + 2) << endl;
38+
exit(EXIT_FAILURE);
39+
}
40+
}
41+
base.emplace_back(temp);
42+
}
1543

16-
for (int i = 0; i < numItems; i++) {
17-
std::vector<double> temp(dim);
18-
for (int d = 0; d < dim; d++) {
19-
temp[d] = distribution(generator);
44+
for (int i = 0; i < M; ++i) {
45+
vector<double> temp(D);
46+
for (int j = 0; j < D; ++j) {
47+
if (!(infile >> temp[j])) {
48+
cerr << "Invalid format in the input file " << filename << " at line " << (N + i + 2) << endl;
49+
exit(EXIT_FAILURE);
50+
}
2051
}
21-
randomItems.emplace_back(temp);
52+
queries.emplace_back(temp);
2253
}
2354

24-
std::shuffle(randomItems.begin(), randomItems.end(), generator);
25-
HNSWGraph myHNSWGraph(10, 30, 30, 10, 2);
26-
for (int i = 0; i < numItems; i++) {
27-
if (i % 10000 == 0) cout << i << endl;
28-
myHNSWGraph.Insert(randomItems[i]);
29-
}
30-
31-
double total_brute_force_time = 0.0;
32-
double total_hnsw_time = 0.0;
33-
34-
cout << "START QUERY" << endl;
35-
int numHits = 0;
36-
for (int i = 0; i < numQueries; i++) {
37-
vector<double> temp(dim);
38-
for (int d = 0; d < dim; d++) {
39-
temp[d] = distribution(generator);
40-
}
41-
Item query(temp);
42-
43-
// Brute force
44-
clock_t begin_time = clock();
45-
vector<pair<double, int>> distPairs;
46-
for (int j = 0; j < numItems; j++) {
47-
if (j == i) continue;
48-
distPairs.emplace_back(query.dist(randomItems[j]), j);
49-
}
50-
sort(distPairs.begin(), distPairs.end());
51-
total_brute_force_time += double( clock () - begin_time ) / CLOCKS_PER_SEC;
52-
53-
begin_time = clock();
54-
vector<int> knns = myHNSWGraph.KNNSearch(query, K);
55-
// cout << "Printing vectors";
56-
// std::cout << "Contents of knns vector:" << std::endl;
57-
// for (size_t i = 0; i < knns.size(); ++i) {
58-
// std::cout << "knns[" << i << "] = " << knns[i] << std::endl;
59-
// }
60-
// cout << "\nPrinted Vectors";
61-
total_hnsw_time += double( clock () - begin_time ) / CLOCKS_PER_SEC;
62-
63-
if (knns[0] == distPairs[0].second) numHits++;
64-
}
65-
cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
55+
infile.close();
6656
}
6757

68-
int main() {
69-
randomTest(10000, 4, 100, 5);
58+
59+
int main(int argc, char* argv[]) {
60+
61+
if (argc != 3) {
62+
cerr << "Usage: " << argv[0] << " <input_filename> <output_filename>" << endl;
63+
return EXIT_FAILURE;
64+
}
65+
66+
string filename = argv[1];
67+
string outputFilename = argv[2];
68+
int K = 5;
69+
70+
ifstream infile(filename);
71+
if (!infile.is_open()) {
72+
cerr << "Unable to open file " << filename << endl;
73+
exit(EXIT_FAILURE);
74+
}
75+
76+
int N = 0, D = 0, M = 0;
77+
78+
std::vector<Item> base, queries;
79+
80+
readInputFromFile(filename, D, N, M, base, queries);
81+
82+
HNSWGraph myHNSWGraph(10, 30, 30, 10, 2);
83+
84+
for (int i = 0; i < N; ++i) {
85+
myHNSWGraph.Insert(base[i]);
86+
}
87+
88+
double total_brute_force_time = 0.0;
89+
double total_hnsw_time = 0.0;
90+
91+
int numHits = 0;
92+
ofstream outfile(outputFilename);
93+
94+
if (!outfile.is_open()) {
95+
cerr << "Unable to open output file " << outputFilename << endl;
96+
exit(EXIT_FAILURE);
97+
}
98+
99+
for (int i = 0; i < M; ++i) {
100+
Item query = queries[i];
101+
clock_t begin_time = clock();
102+
vector<pair<double, int>> distPairs;
103+
for (int j = 0; j < N; ++j) {
104+
if (j == i) continue;
105+
distPairs.emplace_back(query.dist(base[j]), j);
106+
}
107+
sort(distPairs.begin(), distPairs.end());
108+
total_brute_force_time += double(clock() - begin_time) / CLOCKS_PER_SEC;
109+
110+
begin_time = clock();
111+
vector<int> knns = myHNSWGraph.KNNSearch(query, K);
112+
for (size_t idx = 0; idx < knns.size(); ++idx) {
113+
outfile << knns[idx];
114+
if (idx != knns.size() - 1) {
115+
outfile << " ";
116+
}
117+
}
118+
outfile << endl;
119+
total_hnsw_time += double(clock() - begin_time) / CLOCKS_PER_SEC;
120+
121+
if (knns[0] == distPairs[0].second) numHits++;
122+
}
123+
outfile.close();
124+
cout << numHits << " " << total_brute_force_time / M << " " << total_hnsw_time / M << endl;
125+
70126
return 0;
71127
}

main.o

-405 KB
Binary file not shown.

my_program.exe

-481 KB
Binary file not shown.

run.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
./my_program
1+
# Usage: run.sh <input_filename> <output_filename>
# Arguments are quoted so that paths containing spaces are passed through intact.
./my_program "$1" "$2"

sample1.in

Lines changed: 0 additions & 16 deletions
This file was deleted.

sample_inputs/input_generator.cc

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ void dump_func(int cas,int D,int N,int M,double T){
2727
printf("dumping case %d\n",cas);
2828
std::string file_name = "sample" + std::to_string(cas) + ".in";
2929
FILE* fp = fopen(file_name.c_str(),"w");
30-
fprintf(fp,"%d %d %d %f\n",D,N,M,T);
30+
fprintf(fp,"%d %d %d\n",D,N,M);
3131
std::vector<std::vector<double>> data;
3232
for(int i = 0;i < N;++i){
3333
std::vector<double> tmp;
@@ -67,19 +67,10 @@ void dump_func(int cas,int D,int N,int M,double T){
6767
}
6868

6969
int main(){
70-
// dump_func(12,128,10000, 10000,0.9);
7170
// Dimensions, Base Vectors, Queries, Targets
7271
dump_func(1,2,10,5,0.9);
7372
dump_func(2,4,50,10,0.9);
7473
dump_func(3,4,1000,100,0.9);
75-
// dump_func(3,128,1000000,10000,0.9);
76-
// dump_func(4,128,1000000,10000,0.95);
77-
// dump_func(5,256,1000000,10000,0.9);
78-
// dump_func(6,256,1000000,10000,0.95);
79-
// dump_func(7,512,500000,10000,0.9);
80-
// dump_func(8,512,500000,10000,0.95);
81-
// dump_func(9,512,1000000,10000,0.9);
82-
// dump_func(10,512,1000000,10000,0.95);
8374
return 0;
8475
}
8576

0 commit comments

Comments
 (0)