Skip to content

Commit bdcae54

Browse files
committed
[LayerNorm] Don't exit early in the backward pass (fix Dao-AILab#781)
1 parent 36bc29e commit bdcae54

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

flash_attn/ops/triton/layer_norm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,7 @@ def _layer_norm_bwd_kernel(
452452
# Map the program id to the elements of X, DX, and DY it should compute.
453453
row_block_id = tl.program_id(0)
454454
row_start = row_block_id * rows_per_program
455-
if row_start >= M:
456-
return
455+
# Do not early exit if row_start >= M, because we need to write DW and DB
457456
cols = tl.arange(0, BLOCK_N)
458457
mask = cols < N
459458
X += row_start * stride_x_row

0 commit comments

Comments
 (0)