Skip to content

Commit 07f5b21

Browse files
authored
Merge pull request pytorch#702 from gchanan/conservativeAllocator
Improve THCCachingHostAllocator performance by making it reclaim less aggressively
2 parents aeb7a72 + e454870 commit 07f5b21

File tree

4 files changed

+58
-30
lines changed

4 files changed

+58
-30
lines changed

THCCachingHostAllocator.cpp

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <set>
77
#include <stdint.h>
88
#include <unordered_map>
9+
#include <unordered_set>
910
#include <utility>
1011

1112

@@ -23,11 +24,25 @@ struct Block : public BlockSize
2324
{
2425
bool allocated; // true if the block is currently allocated
2526
int event_count; // number of outstanding cuda events
27+
std::unordered_set<THCStream *> streams;
2628

2729
Block(size_t size, void* ptr, bool allocated) :
2830
BlockSize(size, ptr), allocated(allocated), event_count(0) { }
2931
};
3032

33+
struct BlockStreamCleaner {
34+
std::unordered_set<THCStream *> &streams;
35+
36+
BlockStreamCleaner(std::unordered_set<THCStream *> &streams) : streams(streams) {}
37+
~BlockStreamCleaner() {
38+
for(auto it = streams.begin(); it != streams.end(); ++it) {
39+
if (*it != NULL) {
40+
THCStream_free(*it);
41+
}
42+
}
43+
streams.clear();
44+
}
45+
};
3146
static bool BlockComparator(const BlockSize& a, const BlockSize& b)
3247
{
3348
// sort by size, break ties with pointer
@@ -98,21 +113,49 @@ struct HostAllocator
98113
return cudaSuccess;
99114
}
100115

116+
// process outstanding cuda events which may have occurred
117+
cudaError_t err = processEvents();
118+
if (err != cudaSuccess) {
119+
return err;
120+
}
121+
101122
auto it = blocks.find(ptr);
102123
THAssert(it != blocks.end());
103124

104125
Block& block = it->second;
105126
THAssert(block.allocated);
106127

128+
// free (on valid memory) shouldn't fail, so mark unallocated before
129+
// we process the streams.
107130
block.allocated = false;
131+
132+
// since the block has been deallocated, no point in keeping around the
133+
// streams, even in case of error.
134+
BlockStreamCleaner sc(block.streams);
135+
for (auto it = block.streams.begin(); it != block.streams.end(); ++it) {
136+
cudaEvent_t event;
137+
err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
138+
if (err != cudaSuccess) {
139+
return err;
140+
}
141+
142+
err = cudaEventRecord(event, (*it) == NULL ? NULL : (*it)->stream);
143+
if (err != cudaSuccess) {
144+
return err;
145+
}
146+
147+
// the block will not be re-used until all associated events have occurred
148+
block.event_count++;
149+
cuda_events.emplace_back(event, ptr);
150+
}
108151
if (block.event_count == 0) {
109152
// the block can be re-used if there are no outstanding cuda events
110153
available.insert(block);
111154
}
112155
return cudaSuccess;
113156
}
114157

115-
cudaError_t recordEvent(void* ptr, cudaStream_t stream)
158+
cudaError_t recordEvent(void* ptr, THCStream *stream)
116159
{
117160
std::lock_guard<std::mutex> lock(mutex);
118161
cudaError_t err;
@@ -125,27 +168,11 @@ struct HostAllocator
125168

126169
Block& block = it->second;
127170
THAssert(block.allocated);
128-
129-
// process outstanding cuda events which may have occurred
130-
err = processEvents();
131-
if (err != cudaSuccess) {
132-
return err;
171+
auto res = block.streams.emplace(stream);
172+
if (res.second == true && stream != NULL) {
173+
THCStream_retain(stream);
133174
}
134175

135-
// create and record an event in the given stream
136-
cudaEvent_t event;
137-
err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
138-
if (err != cudaSuccess) {
139-
return err;
140-
}
141-
err = cudaEventRecord(event, stream);
142-
if (err != cudaSuccess) {
143-
return err;
144-
}
145-
146-
// the block will not be re-used until all associated events have occurred
147-
block.event_count++;
148-
cuda_events.emplace_back(event, ptr);
149176
return cudaSuccess;
150177
}
151178

@@ -186,18 +213,17 @@ struct HostAllocator
186213
std::lock_guard<std::mutex> lock(mutex);
187214

188215
// remove events for freed blocks
189-
std::deque<std::pair<cudaEvent_t, void*>> new_events;
190216
for (auto it = cuda_events.begin(); it != cuda_events.end(); ++it) {
191217
cudaEvent_t event = it->first;
192218
Block& block = blocks.at(it->second);
193219
if (!block.allocated) {
194220
THCudaCheckWarn(cudaEventDestroy(event));
195221
block.event_count--;
196-
} else {
197-
new_events.push_back(*it);
198222
}
199223
}
200-
cuda_events.swap(new_events);
224+
225+
// all cuda_events have been processed
226+
cuda_events.clear();
201227

202228
// clear list of available blocks
203229
available.clear();
@@ -232,7 +258,7 @@ static void THCCachingHostAllocator_free(void* ctx, void* ptr)
232258
allocator.free(ptr);
233259
}
234260

235-
cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, cudaStream_t stream)
261+
cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, THCStream *stream)
236262
{
237263
return allocator.recordEvent(ptr, stream);
238264
}

THCCachingHostAllocator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define THC_CACHING_HOST_ALLOCATOR_INC
33

44
#include "THCGeneral.h"
5+
#include "THCStream.h"
56

67
//
78
// A caching allocator for CUDA host allocations (pinned memory).
@@ -22,7 +23,7 @@ THC_API THAllocator THCCachingHostAllocator;
2223

2324
// Records an event in the specified stream. The allocation 'ptr' will not be
2425
// re-used until the event has occurred.
25-
THC_API cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, cudaStream_t stream);
26+
THC_API cudaError_t THCCachingHostAllocator_recordEvent(void *ptr, THCStream *stream);
2627

2728
// Releases cached pinned memory allocations via cudaHostFree
2829
THC_API void THCCachingHostAllocator_emptyCache(void);

THCTensorCopy.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "THCTensor.h"
55
#include "THCGeneral.h"
66
#include "THCHalf.h"
7+
#include "THCStream.h"
78

89
#include "generic/THCTensorCopy.h"
910
#include "THCGenerateAllTypes.h"

generic/THCTensorCopy.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,12 @@ void THCTensor_(copyAsyncCPU)(THCState *state, THCTensor *self, struct THTensor
118118
THCudaCheck(cudaSetDevice(tensorDevice));
119119
}
120120

121-
cudaStream_t stream = THCState_getCurrentStream(state);
121+
THCStream *stream = THCState_getStream(state);
122122
THCudaCheck(cudaMemcpyAsync(THCTensor_(data)(state, self),
123123
THTensor_(data)(src),
124124
THTensor_(nElement)(src) * sizeof(real),
125125
cudaMemcpyHostToDevice,
126-
stream));
126+
stream == NULL ? NULL : stream->stream));
127127

128128
THCudaCheck(THCCachingHostAllocator_recordEvent(src->storage->data, stream));
129129

@@ -149,12 +149,12 @@ void THTensor_(copyAsyncCuda)(THCState *state, THTensor *self, struct THCTensor
149149
THCudaCheck(cudaSetDevice(tensorDevice));
150150
}
151151

152-
cudaStream_t stream = THCState_getCurrentStream(state);
152+
THCStream *stream = THCState_getStream(state);
153153
THCudaCheck(cudaMemcpyAsync(THTensor_(data)(self),
154154
THCTensor_(data)(state, src),
155155
THCTensor_(nElement)(state, src) * sizeof(real),
156156
cudaMemcpyDeviceToHost,
157-
stream));
157+
stream == NULL ? NULL : stream->stream));
158158

159159
THCudaCheck(THCCachingHostAllocator_recordEvent(src->storage->data, stream));
160160

0 commit comments

Comments
 (0)