Skip to content

Commit 916ba55

Browse files
Replace unordered_map with phmap in hetero_sample (#266)
1 parent ae22058 commit 916ba55

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
#include "utils.h"
44

5-
#include "parallel_hashmap/phmap.h"
6-
75
#ifdef _WIN32
86
#include <process.h>
97
#endif
@@ -142,21 +140,21 @@ hetero_sample(const vector<node_t> &node_types,
142140
const int64_t num_hops) {
143141

144142
// Create a mapping to convert single string relations to edge type triplets:
145-
unordered_map<rel_t, edge_t> to_edge_type;
143+
phmap::flat_hash_map<rel_t, edge_t> to_edge_type;
146144
for (const auto &k : edge_types)
147145
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
148146

149147
// Initialize some data structures for the sampling process:
150-
unordered_map<node_t, vector<int64_t>> samples_dict;
151-
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
152-
unordered_map<node_t, vector<int64_t>> root_time_dict;
148+
phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
149+
phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
150+
phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
153151
for (const auto &node_type : node_types) {
154152
samples_dict[node_type];
155153
to_local_node_dict[node_type];
156154
root_time_dict[node_type];
157155
}
158156

159-
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
157+
phmap::flat_hash_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
160158
for (const auto &kv : colptr_dict) {
161159
const auto &rel_type = kv.key();
162160
rows_dict[rel_type];
@@ -188,7 +186,7 @@ hetero_sample(const vector<node_t> &node_types,
188186
}
189187
}
190188

191-
unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
189+
phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
192190
for (const auto &kv : samples_dict)
193191
slice_dict[kv.first] = {0, kv.second.size()};
194192

@@ -339,7 +337,7 @@ hetero_sample(const vector<node_t> &node_types,
339337
}
340338

341339
if (!directed) { // Construct the subgraph among the sampled nodes:
342-
unordered_map<int64_t, int64_t>::iterator iter;
340+
phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
343341
for (const auto &kv : colptr_dict) {
344342
const auto &rel_type = kv.key();
345343
const auto &edge_type = to_edge_type[rel_type];
@@ -455,4 +453,4 @@ hetero_temporal_neighbor_sample_cpu(
455453
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
456454
num_neighbors_dict, node_time_dict, num_hops);
457455
}
458-
}
456+
}

csrc/cpu/utils.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "../extensions.h"
4+
#include "parallel_hashmap/phmap.h"
45

56
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
67
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
@@ -27,7 +28,7 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec,
2728

2829
template <typename key_t, typename scalar_t>
2930
inline c10::Dict<key_t, torch::Tensor>
30-
from_vector(const std::unordered_map<key_t, std::vector<scalar_t>> &vec_dict,
31+
from_vector(const phmap::flat_hash_map<key_t, std::vector<scalar_t>> &vec_dict,
3132
bool inplace = false) {
3233
c10::Dict<key_t, torch::Tensor> out_dict;
3334
for (const auto &kv : vec_dict)
@@ -91,7 +92,7 @@ template <bool replace>
9192
inline void
9293
uniform_choice(const int64_t population, const int64_t num_samples,
9394
const int64_t *idx_data, std::vector<int64_t> *samples,
94-
std::unordered_map<int64_t, int64_t> *to_local_node) {
95+
phmap::flat_hash_map<int64_t, int64_t> *to_local_node) {
9596

9697
if (population == 0 || num_samples == 0)
9798
return;

0 commit comments

Comments
 (0)