Skip to content

Commit eb02f59

Browse files
committed
Added gitignore and sample input
1 parent f8768c4 commit eb02f59

File tree

4 files changed

+159
-0
lines changed

4 files changed

+159
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Ignore files with the .exe extension.
*.exe
2+
# Ignore files with the .in extension (the input generator writes sample<N>.in files).
*.in

main1.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#include "hnsw.h"
2+
3+
#include <algorithm>
4+
#include <ctime>
5+
#include <iostream>
6+
#include <random>
7+
#include <vector>
8+
using namespace std;
9+
10+
void randomTest(int numItems, int dim, int numQueries, int K) {
11+
std::default_random_engine generator;
12+
std::uniform_real_distribution<double> distribution(0.0, 1.0);
13+
std::vector<Item> randomItems;
14+
randomItems.reserve(numItems);
15+
16+
for (int i = 0; i < numItems; i++) {
17+
std::vector<double> temp(dim);
18+
for (int d = 0; d < dim; d++) {
19+
temp[d] = distribution(generator);
20+
}
21+
randomItems.emplace_back(temp);
22+
}
23+
24+
std::shuffle(randomItems.begin(), randomItems.end(), generator);
25+
HNSWGraph myHNSWGraph(10, 30, 30, 10, 2);
26+
for (int i = 0; i < numItems; i++) {
27+
if (i % 10000 == 0) cout << i << endl;
28+
myHNSWGraph.Insert(randomItems[i]);
29+
}
30+
31+
double total_brute_force_time = 0.0;
32+
double total_hnsw_time = 0.0;
33+
34+
cout << "START QUERY" << endl;
35+
int numHits = 0;
36+
for (int i = 0; i < numQueries; i++) {
37+
vector<double> temp(dim);
38+
for (int d = 0; d < dim; d++) {
39+
temp[d] = distribution(generator);
40+
}
41+
Item query(temp);
42+
43+
// Brute force
44+
clock_t begin_time = clock();
45+
vector<pair<double, int>> distPairs;
46+
for (int j = 0; j < numItems; j++) {
47+
if (j == i) continue;
48+
distPairs.emplace_back(query.dist(randomItems[j]), j);
49+
}
50+
sort(distPairs.begin(), distPairs.end());
51+
total_brute_force_time += double( clock () - begin_time ) / CLOCKS_PER_SEC;
52+
53+
begin_time = clock();
54+
vector<int> knns = myHNSWGraph.KNNSearch(query, K);
55+
// cout << "Printing vectors";
56+
// std::cout << "Contents of knns vector:" << std::endl;
57+
// for (size_t i = 0; i < knns.size(); ++i) {
58+
// std::cout << "knns[" << i << "] = " << knns[i] << std::endl;
59+
// }
60+
// cout << "\nPrinted Vectors";
61+
total_hnsw_time += double( clock () - begin_time ) / CLOCKS_PER_SEC;
62+
63+
if (knns[0] == distPairs[0].second) numHits++;
64+
}
65+
cout << numHits << " " << total_brute_force_time / numQueries << " " << total_hnsw_time / numQueries << endl;
66+
}
67+
68+
// Entry point: runs one random benchmark with a fixed configuration.
int main() {
    const int kNumItems = 10000;   // vectors inserted into the index
    const int kDim = 4;            // components per vector
    const int kNumQueries = 100;   // query vectors evaluated
    const int kNeighbours = 5;     // K passed to the search

    randomTest(kNumItems, kDim, kNumQueries, kNeighbours);
    return 0;
}

sample_inputs/compile_generator.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Build input_generator.cc into the ./input_generator binary (C++11, -O3).
g++ input_generator.cc -o input_generator -O3 -std=c++11

sample_inputs/input_generator.cc

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#include<stdio.h>
2+
#include<string.h>
3+
#include<random>
4+
#include<string>
5+
#include<algorithm>
6+
#include<set>
7+
#include<random>
8+
#include<cmath>
9+
10+
std::mt19937 generator(123456789);
11+
std::normal_distribution<double> dist(0.0,1.0);
12+
const int max_iter = 100;
13+
14+
/**
 * Cosine similarity between two vectors: (u . v) / (|u| * |v|).
 *
 * @param u First vector.
 * @param v Second vector; must have the same length as u.
 * @return  The cosine of the angle between u and v. Returns 0.0 when the
 *          lengths differ or when either vector has zero norm (including the
 *          empty vector), because the similarity is undefined in those cases
 *          and the plain formula would read out of bounds or produce NaN.
 */
double cosine_similarity(const std::vector<double>& u,const std::vector<double>& v){
    if(u.size() != v.size())
        return 0.0;

    double ip = 0;
    double sumu2 = 0;
    double sumv2 = 0;
    for(size_t i = 0;i < u.size();++i){
        ip += u[i] * v[i];
        sumu2 += u[i] * u[i];
        sumv2 += v[i] * v[i];
    }

    // Exact-zero check on purpose: only a true zero norm makes the division invalid.
    if(sumu2 == 0.0 || sumv2 == 0.0)
        return 0.0;

    return ip / std::sqrt(sumu2) / std::sqrt(sumv2);
}
25+
26+
void dump_func(int cas,int D,int N,int M,double T){
27+
printf("dumping case %d\n",cas);
28+
std::string file_name = "sample" + std::to_string(cas) + ".in";
29+
FILE* fp = fopen(file_name.c_str(),"w");
30+
fprintf(fp,"%d %d %d %f\n",D,N,M,T);
31+
std::vector<std::vector<double>> data;
32+
for(int i = 0;i < N;++i){
33+
std::vector<double> tmp;
34+
for(int j = 0;j < D;++j)
35+
tmp.push_back(dist(generator));
36+
data.push_back(tmp);
37+
fprintf(fp,"%f",tmp[0]);
38+
for(int j = 1;j < D;++j)
39+
fprintf(fp," %f",tmp[j]);
40+
fprintf(fp,"\n");
41+
}
42+
43+
std::uniform_int_distribution<int> distint(0,N - 1);
44+
for(int i = 0;i < M;++i){
45+
double eps = 1e-1;
46+
bool found = false;
47+
while(!found){
48+
std::normal_distribution<double> disteps(0.0,eps);
49+
for(int iter = 0;iter < max_iter;++iter){
50+
int idx = distint(generator);
51+
std::vector<double> q = data[idx];
52+
for(int j = 0;j < D;++j)
53+
q[j] += disteps(generator);
54+
if(cosine_similarity(q,data[idx]) > T){
55+
found = true;
56+
fprintf(fp,"%f",q[0]);
57+
for(int j = 1;j < D;++j)
58+
fprintf(fp," %f",q[j]);
59+
fprintf(fp,"\n");
60+
break;
61+
}
62+
}
63+
eps /= 2;
64+
}
65+
}
66+
fclose(fp);
67+
}
68+
69+
// Entry point: writes the small sample input files listed in the table below.
int main(){
    // One row per generated file:
    // case number, dimensions (D), base vectors (N), queries (M), similarity threshold (T).
    struct CaseSpec { int cas; int D; int N; int M; double T; };
    const CaseSpec enabled_cases[] = {
        {1, 2, 10, 5, 0.9},
        {2, 4, 50, 10, 0.9},
        {3, 4, 1000, 100, 0.9},
    };
    for(const CaseSpec& spec : enabled_cases)
        dump_func(spec.cas, spec.D, spec.N, spec.M, spec.T);

    // Larger configurations, currently disabled:
    // dump_func(12,128,10000, 10000,0.9);
    // dump_func(3,128,1000000,10000,0.9);
    // dump_func(4,128,1000000,10000,0.95);
    // dump_func(5,256,1000000,10000,0.9);
    // dump_func(6,256,1000000,10000,0.95);
    // dump_func(7,512,500000,10000,0.9);
    // dump_func(8,512,500000,10000,0.95);
    // dump_func(9,512,1000000,10000,0.9);
    // dump_func(10,512,1000000,10000,0.95);
    return 0;
}
85+

0 commit comments

Comments
 (0)