@@ -196,7 +196,7 @@ class DecodingCGCache:
196196
197197@torch .inference_mode ()
198198def update_graph_cache (model , cache , batch_size , seqlen_og , max_seqlen , tensor_parallel = 1 ,
199- dtype = None ):
199+ dtype = None , n_warmups = 2 ):
200200 if cache is None :
201201 cache = DecodingCGCache ()
202202 param_example = next (iter (model .parameters ()))
@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
228228 if s_type not in cache .callables :
229229 seqlen = min (max (seqlen_og , seqlen_type_to_seqlen (s_type )), max_seqlen )
230230 cache .callables [s_type ] = capture_graph (
231- model , cache .inference_params , batch_size , seqlen_og , seqlen , mempool = cache .mempool
231+ model , cache .inference_params , batch_size , seqlen_og , seqlen , mempool = cache .mempool ,
232+ n_warmups = n_warmups
232233 )
233234
234235 def dispatch (input_ids , position_ids , seqlen ):
@@ -239,7 +240,8 @@ def dispatch(input_ids, position_ids, seqlen):
239240 return cache
240241
241242
242- def capture_graph (model , inference_params , batch_size , seqlen_og , max_seqlen , mempool = None ):
243+ def capture_graph (model , inference_params , batch_size , seqlen_og , max_seqlen , mempool = None ,
244+ n_warmups = 2 ):
243245 assert max_seqlen >= seqlen_og
244246 device = next (iter (model .parameters ())).device
245247 input_ids = torch .full ((batch_size , 1 ), 0 , dtype = torch .long , device = device )
@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
250252 s = torch .cuda .Stream ()
251253 s .wait_stream (torch .cuda .current_stream ())
252254 with torch .cuda .stream (s ):
253- for _ in range (2 ):
255+ for _ in range (n_warmups ):
254256 logits = model (input_ids , position_ids = position_ids ,
255257 inference_params = inference_params ).logits [:, - 1 ]
256258 s .synchronize ()
259+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
260+ # which (as I interpret the documentation) requires that graph launches and
261+ # non-captured launches do not overlap. I'm not sure if this is required.
262+ if torch .distributed .is_initialized ():
263+ torch .distributed .barrier ()
257264 torch .cuda .current_stream ().wait_stream (s )
258265 # Captures the graph
259266 # To allow capture, automatically sets a side stream as the current stream in the context
0 commit comments