We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 36bc29e, commit bdcae54 (Copy full SHA for bdcae54)
flash_attn/ops/triton/layer_norm.py
@@ -452,8 +452,7 @@ def _layer_norm_bwd_kernel(
452
# Map the program id to the elements of X, DX, and DY it should compute.
453
row_block_id = tl.program_id(0)
454
row_start = row_block_id * rows_per_program
455
- if row_start >= M:
456
- return
+ # Do not early exit if row_start >= M, because we need to write DW and DB
457
cols = tl.arange(0, BLOCK_N)
458
mask = cols < N
459
X += row_start * stride_x_row
0 commit comments