@@ -34,34 +34,37 @@ def cross_entropy_fwd_kernel(
     total_classes,
     class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
     n_cols,  # shapes
-    n_rows,
     logits_row_stride,  # strides
     BLOCK_SIZE: tl.constexpr,
     HAS_SMOOTHING: tl.constexpr,
     # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
     SPLIT: tl.constexpr,
 ):
     row_idx = tl.program_id(0)
-    col_block_idx = tl.program_id(1)
     logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
-    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    sum_logits = 0.0  # For smoothing
+    # Statistics for online softmax
+    m_i = -float("inf")
+    l_i = 0.0
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        cols = col_offset + tl.arange(0, BLOCK_SIZE)
+        logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
+            tl.float32
+        ) * logit_scale
+        if HAS_SMOOTHING:
+            sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
+        m_i_new = tl.maximum(m_i, tl.max(logits))
+        l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
+        m_i = m_i_new
+    lse = tl.log(l_i) + m_i
+    tl.store(lse_ptr + row_idx, lse)
     label_idx = tl.load(labels_ptr + row_idx)
-    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
-        tl.float32
-    ) * logit_scale
-    max_logits = tl.max(logits, 0)
-    if HAS_SMOOTHING:
-        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
-    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
-    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
     if label_idx == ignore_index:
         loss = 0.0
         z_loss = 0.0
     else:
         label_idx -= class_start_idx
-        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
-            n_cols, (col_block_idx + 1) * BLOCK_SIZE
-        ):
+        if label_idx >= 0 and label_idx < n_cols:
             logits_label = tl.load(logits_ptr + label_idx) * logit_scale
             if HAS_SMOOTHING:
                 loss = (
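
The rewritten kernel no longer launches one program per column block: each program now sweeps its whole row and computes the log-sum-exp in a single pass with the online softmax recurrence (`m_i` is the running max, `l_i` the running sum of exponentials, rescaled whenever the max grows). A minimal NumPy sketch of that recurrence, outside Triton and with illustrative names (`online_logsumexp`, `block_size`), looks like this:

```python
import numpy as np

def online_logsumexp(logits, block_size=4):
    # m is the running max, l the running sum of exp(x - m) over blocks seen so far.
    m, l = -np.inf, 0.0
    for start in range(0, len(logits), block_size):
        block = logits[start:start + block_size]
        m_new = max(m, block.max())
        # Rescale the old sum when the max grows, then add this block's contribution.
        l = np.exp(m - m_new) * l + np.exp(block - m_new).sum()
        m = m_new
    return np.log(l) + m

x = np.random.randn(1000)
reference = np.log(np.exp(x - x.max()).sum()) + x.max()
assert np.allclose(online_logsumexp(x), reference)
```

This is the same online-softmax bookkeeping FlashAttention uses for attention: the per-row state is O(1), so arbitrarily wide vocabularies can be processed without allocating one output slot per column block.
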
@@ -82,9 +85,9 @@ def cross_entropy_fwd_kernel(
             loss += z_loss
         else:
             z_loss = 0.0
-    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
+    tl.store(loss_ptr + row_idx, loss)
     if not SPLIT:
-        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
+        tl.store(z_loss_ptr + row_idx, z_loss)


 @triton.heuristics(
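
For orientation, here is a rough PyTorch reference for what a single program now computes per row in the non-split case (no tensor parallelism, z-loss omitted). The function name and defaults are illustrative; the exact smoothing formula should be read from the kernel itself:

```python
import torch

def row_cross_entropy(logits_row, label, logit_scale=1.0, smoothing=0.0, ignore_index=-100):
    # Rough per-row reference (non-split, z-loss omitted); names are illustrative.
    x = logits_row.float() * logit_scale
    lse = torch.logsumexp(x, dim=0)
    if label == ignore_index:
        return torch.zeros((), device=x.device), lse
    if smoothing > 0.0:
        # Smoothed loss: mix the target logit with the mean logit over all classes.
        loss = lse - (1 - smoothing) * x[label] - smoothing * x.sum() / x.numel()
    else:
        loss = lse - x[label]
    return loss, lse
```
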
@@ -161,27 +164,20 @@ def forward(

         if logits.stride(-1) != 1:
             logits = logits.contiguous()
-        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
-        MAX_BLOCK_SIZE = 64 * 1024
+        MAX_BLOCK_SIZE = 16 * 1024
         BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
         num_warps = (
             4
             if BLOCK_SIZE < 2048
             else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
         )
-        # We may split the lse computation across multiple blocks, then do a reduction
-        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
-        # where having just one thread block processing more than 64k elements is slow.
-        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
-        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
-        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
-        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
-        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
-        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
+        losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+        z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
         # Need this, otherwise Triton tries to launch from cuda:0 and we get
         # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
         with torch.cuda.device(logits.device.index):
-            cross_entropy_fwd_kernel[(n_rows, n_splits)](
+            cross_entropy_fwd_kernel[(n_rows,)](
                 losses,  # data ptrs
                 lse,
                 z_losses,
@@ -194,23 +190,19 @@ def forward(
             total_classes,
             class_start_idx,
             n_cols,  # shapes
-            n_rows,
             logits.stride(0),  # strides
             BLOCK_SIZE=BLOCK_SIZE,  # constants
             num_warps=num_warps,
-            SPLIT=split,
+            SPLIT=world_size > 1,
         )

-        if split:
+        if world_size > 1:
             # If there's no smoothing, if labels are in the vocab of this partition, losses contains
             # - predicted logit, and 0 otherwise.
             # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
             # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
             # For labels not in the vocab of this partition, losses contains
             # -0.1 * sum logit / total_classes.
-            if n_splits > 1:
-                lse = torch.logsumexp(lse, dim=0)
-                losses = losses.sum(dim=0)
             if world_size > 1:
                 lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                 torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
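
In the tensor-parallel (SPLIT) path, each rank's kernel writes a partial loss without the LSE term plus its local LSE, as the comments above describe. One way to picture the reduction after the all_gather, as a rough sketch with hypothetical names (`partial_losses`, `partial_lses`, both shaped `(world_size, n_rows)`) and ignoring `ignore_index` masking:

```python
import torch

def combine_tp_partials(partial_losses, partial_lses):
    # Global normalizer: logsumexp over the per-rank log-sum-exps.
    lse = torch.logsumexp(partial_lses, dim=0)
    # Each rank contributed its share of the -(weighted) target/smoothing logits with
    # the LSE term left out (SPLIT=True), so summing the partials and adding the
    # global LSE reconstructs the full cross-entropy per row.
    return partial_losses.sum(dim=0) + lse
```

The real code additionally has to keep rows with `ignore_index` at zero loss and handle the z-loss term; the sketch only shows the LSE bookkeeping.
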
@@ -243,6 +235,7 @@ def forward(
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward

+
         return losses, z_losses

     @staticmethod