66#include < set>
77#include < stdint.h>
88#include < unordered_map>
9+ #include < unordered_set>
910#include < utility>
1011
1112
@@ -23,11 +24,25 @@ struct Block : public BlockSize
2324{
2425 bool allocated; // true if the block is currently allocated
2526 int event_count; // number of outstanding cuda events
27+ std::unordered_set<THCStream *> streams;
2628
2729 Block (size_t size, void * ptr, bool allocated) :
2830 BlockSize (size, ptr), allocated(allocated), event_count(0 ) { }
2931};
3032
33+ struct BlockStreamCleaner {
34+ std::unordered_set<THCStream *> &streams;
35+
36+ BlockStreamCleaner (std::unordered_set<THCStream *> &streams) : streams(streams) {}
37+ ~BlockStreamCleaner () {
38+ for (auto it = streams.begin (); it != streams.end (); ++it) {
39+ if (*it != NULL ) {
40+ THCStream_free (*it);
41+ }
42+ }
43+ streams.clear ();
44+ }
45+ };
3146static bool BlockComparator (const BlockSize& a, const BlockSize& b)
3247{
3348 // sort by size, break ties with pointer
@@ -98,21 +113,49 @@ struct HostAllocator
98113 return cudaSuccess;
99114 }
100115
116+ // process outstanding cuda events which may have occurred
117+ cudaError_t err = processEvents ();
118+ if (err != cudaSuccess) {
119+ return err;
120+ }
121+
101122 auto it = blocks.find (ptr);
102123 THAssert (it != blocks.end ());
103124
104125 Block& block = it->second ;
105126 THAssert (block.allocated );
106127
128+ // free (on valid memory) shouldn't fail, so mark unallocated before
129+ // we process the streams.
107130 block.allocated = false ;
131+
132+ // since the block has been deallocated, no point in keeping around the
133+ // streams, even in case of error.
134+ BlockStreamCleaner sc (block.streams );
135+ for (auto it = block.streams .begin (); it != block.streams .end (); ++it) {
136+ cudaEvent_t event;
137+ err = cudaEventCreateWithFlags (&event, cudaEventDisableTiming);
138+ if (err != cudaSuccess) {
139+ return err;
140+ }
141+
142+ err = cudaEventRecord (event, (*it) == NULL ? NULL : (*it)->stream );
143+ if (err != cudaSuccess) {
144+ return err;
145+ }
146+
147+ // the block will not be re-used until all associated events have occurred
148+ block.event_count ++;
149+ cuda_events.emplace_back (event, ptr);
150+ }
108151 if (block.event_count == 0 ) {
109152 // the block can be re-used if there are no outstanding cuda events
110153 available.insert (block);
111154 }
112155 return cudaSuccess;
113156 }
114157
115- cudaError_t recordEvent (void * ptr, cudaStream_t stream)
158+ cudaError_t recordEvent (void * ptr, THCStream * stream)
116159 {
117160 std::lock_guard<std::mutex> lock (mutex);
118161 cudaError_t err;
@@ -125,27 +168,11 @@ struct HostAllocator
125168
126169 Block& block = it->second ;
127170 THAssert (block.allocated );
128-
129- // process outstanding cuda events which may have occurred
130- err = processEvents ();
131- if (err != cudaSuccess) {
132- return err;
171+ auto res = block.streams .emplace (stream);
172+ if (res.second == true && stream != NULL ) {
173+ THCStream_retain (stream);
133174 }
134175
135- // create and record an event in the given stream
136- cudaEvent_t event;
137- err = cudaEventCreateWithFlags (&event, cudaEventDisableTiming);
138- if (err != cudaSuccess) {
139- return err;
140- }
141- err = cudaEventRecord (event, stream);
142- if (err != cudaSuccess) {
143- return err;
144- }
145-
146- // the block will not be re-used until all associated events have occured
147- block.event_count ++;
148- cuda_events.emplace_back (event, ptr);
149176 return cudaSuccess;
150177 }
151178
@@ -186,18 +213,17 @@ struct HostAllocator
186213 std::lock_guard<std::mutex> lock (mutex);
187214
188215 // remove events for freed blocks
189- std::deque<std::pair<cudaEvent_t, void *>> new_events;
190216 for (auto it = cuda_events.begin (); it != cuda_events.end (); ++it) {
191217 cudaEvent_t event = it->first ;
192218 Block& block = blocks.at (it->second );
193219 if (!block.allocated ) {
194220 THCudaCheckWarn (cudaEventDestroy (event));
195221 block.event_count --;
196- } else {
197- new_events.push_back (*it);
198222 }
199223 }
200- cuda_events.swap (new_events);
224+
225+ // all cuda_events have been processed
226+ cuda_events.clear ();
201227
202228 // clear list of available blocks
203229 available.clear ();
@@ -232,7 +258,7 @@ static void THCCachingHostAllocator_free(void* ctx, void* ptr)
232258 allocator.free (ptr);
233259}
234260
235- cudaError_t THCCachingHostAllocator_recordEvent (void *ptr, cudaStream_t stream)
261+ cudaError_t THCCachingHostAllocator_recordEvent (void *ptr, THCStream * stream)
236262{
237263 return allocator.recordEvent (ptr, stream);
238264}
0 commit comments