@@ -102,7 +102,9 @@ void THSetArgErrorHandler( void (*torchArgErrorHandlerFunction_)(int argNumber,
// GC hook state and heap-size accounting shared by the functions below.
// (Diff rendering: a `-` line is the removed declaration, `+` lines replace it.)
// torchGCFunction is the per-thread GC callback; maybeTriggerGC invokes it
// with torchGCData as its only argument when it is non-NULL.
102102static __thread void (* torchGCFunction )(void * data ) = NULL ;
103103static __thread void * torchGCData ;
// Not thread-local: a single total, modified through THAtomicAddLong.
104104static long heapSize = 0 ;
105- static __thread long heapSoftmax = 300000000 ; // 300MB, adjusted upward dynamically
// Per-thread sum of the sizes passed to THHeapUpdate that have not yet been
// added to heapSize; applyHeapDelta adds it to heapSize and resets it to 0.
105+ static __thread long heapDelta = 0 ;
// NOTE(review): THHeapUpdate tests abs(heapDelta) against this limit; abs()
// takes an int while heapDelta is a long -- labs() is probably intended, confirm.
// 1e6 and 3e8 are floating literals implicitly converted to long.
106+ static const long heapMaxDelta = 1e6 ; // limit to +/- 1MB before updating heapSize
107+ static __thread long heapSoftmax = 3e8 ; // 300MB, adjusted upward dynamically
106108static const double heapSoftmaxGrowthThresh = 0.8 ; // grow softmax if >80% max after GC
107109static const double heapSoftmaxGrowthFactor = 1.4 ; // grow softmax by 40%
108110
@@ -134,29 +136,39 @@ static long getAllocSize(void *ptr) {
134136#endif
135137}
136138
// Flushes the calling thread's pending heapDelta into the shared heapSize via
// THAtomicAddLong, then resets heapDelta to 0.
// Returns THAtomicAddLong's result plus heapDelta; this is the updated total
// if THAtomicAddLong returns the value held before the add -- TODO confirm
// against its definition, which is not visible here.
139+ static long applyHeapDelta () {
140+ long newHeapSize = THAtomicAddLong (& heapSize , heapDelta ) + heapDelta ;
141+ heapDelta = 0 ;
142+ return newHeapSize ;
143+ }
144+
137145/* (1) if the torch-allocated heap size exceeds the soft max, run GC
138146 * (2) if post-GC heap size exceeds 80% of the soft max, increase the
139147 * soft max by 40%
140148 */
// curHeapSize: heap size to compare against heapSoftmax (THHeapUpdate passes
// the value returned by applyHeapDelta). Nothing happens when torchGCFunction
// is NULL or curHeapSize does not exceed heapSoftmax.
// Diff rendering: `-` lines are the removed version, `+` lines replace them.
141149static void maybeTriggerGC (long curHeapSize ) {
142- if (torchGCFunction && curHeapSize > heapSoftmax ) {
150+ if (torchGCFunction && curHeapSize > heapSoftmax ) {
143151 torchGCFunction (torchGCData );
144- long newHeapSize = THAtomicGetLong (& heapSize );
145- if (newHeapSize > heapSoftmax * heapSoftmaxGrowthThresh ) {
152+
153+ // ensure heapSize is accurate before updating heapSoftmax
// Only the calling thread's heapDelta is flushed here (it is __thread);
// deltas still pending in other threads are not reflected in newHeapSize.
154+ long newHeapSize = applyHeapDelta ();
155+
// heapSoftmax is __thread, so growth changes only this thread's limit. Both
// products are evaluated in double; the assignment converts back to long.
156+ if (newHeapSize > heapSoftmax * heapSoftmaxGrowthThresh ) {
146157 heapSoftmax = heapSoftmax * heapSoftmaxGrowthFactor ;
147158 }
148159 }
149160}
150161
151162// hooks into the TH heap tracking
152163void THHeapUpdate (long size ) {
153- long newHeapSize = THAtomicAddLong ( & heapSize , size ) + size ;
164+ heapDelta += size ;
154165
155- # ifdef TH_CHECK_HEAP_UPDATE
156- if (newHeapSize < 0 ) {
157- THError ( "Torch heap size <0 ?" ) ;
166+ // batch updates to global heapSize to minimize thread contention
167+ if (abs ( heapDelta ) < heapMaxDelta ) {
168+ return ;
158169 }
159- #endif
170+
171+ long newHeapSize = applyHeapDelta ();
160172
161173 if (size > 0 ) {
162174 maybeTriggerGC (newHeapSize );
0 commit comments