Skip to content

Commit 5441a69

Browse files
committed
Batch updates to global heapSize
1 parent 9e1b9dd commit 5441a69

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

lib/TH/THGeneral.c

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber,
102102
static __thread void (*torchGCFunction)(void *data) = NULL;
103103
static __thread void *torchGCData;
104104
static long heapSize = 0;
105-
static __thread long heapSoftmax = 300000000; // 300MB, adjusted upward dynamically
105+
static __thread long heapDelta = 0;
106+
static const long heapMaxDelta = 1e6; // limit to +/- 1MB before updating heapSize
107+
static __thread long heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
106108
static const double heapSoftmaxGrowthThresh = 0.8; // grow softmax if >80% max after GC
107109
static const double heapSoftmaxGrowthFactor = 1.4; // grow softmax by 40%
108110

@@ -134,29 +136,39 @@ static long getAllocSize(void *ptr) {
134136
#endif
135137
}
136138

139+
static long applyHeapDelta() {
140+
long newHeapSize = THAtomicAddLong(&heapSize, heapDelta) + heapDelta;
141+
heapDelta = 0;
142+
return newHeapSize;
143+
}
144+
137145
/* (1) if the torch-allocated heap size exceeds the soft max, run GC
138146
* (2) if post-GC heap size exceeds 80% of the soft max, increase the
139147
* soft max by 40%
140148
*/
141149
static void maybeTriggerGC(long curHeapSize) {
142-
if(torchGCFunction && curHeapSize > heapSoftmax ) {
150+
if (torchGCFunction && curHeapSize > heapSoftmax) {
143151
torchGCFunction(torchGCData);
144-
long newHeapSize = THAtomicGetLong(&heapSize);
145-
if(newHeapSize > heapSoftmax * heapSoftmaxGrowthThresh) {
152+
153+
// ensure heapSize is accurate before updating heapSoftmax
154+
long newHeapSize = applyHeapDelta();
155+
156+
if (newHeapSize > heapSoftmax * heapSoftmaxGrowthThresh) {
146157
heapSoftmax = heapSoftmax * heapSoftmaxGrowthFactor;
147158
}
148159
}
149160
}
150161

151162
// hooks into the TH heap tracking
152163
void THHeapUpdate(long size) {
153-
long newHeapSize = THAtomicAddLong(&heapSize, size) + size;
164+
heapDelta += size;
154165

155-
# ifdef TH_CHECK_HEAP_UPDATE
156-
if (newHeapSize < 0) {
157-
THError("Torch heap size <0 ?");
166+
// batch updates to global heapSize to minimize thread contention
167+
if (abs(heapDelta) < heapMaxDelta) {
168+
return;
158169
}
159-
#endif
170+
171+
long newHeapSize = applyHeapDelta();
160172

161173
if (size > 0) {
162174
maybeTriggerGC(newHeapSize);

0 commit comments

Comments
 (0)