Skip to content

Commit 9c3ec7b

Browse files
[TUTORIALS] resolve uninitialized memory access when M< GROUP_SIZE_M in LN tutorial (triton-lang#2837)
👋 There's currently a (minor) issue in the LN tutorial where if M is smaller than GROUP_SIZE_M parts of _dw/_db will contain uninitialised memory which is later summed into the final dw/db.
1 parent 03ceaa6 commit 9c3ec7b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/tutorials/05-layer-norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def backward(ctx, dy):
283283
grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
284284
# accumulate partial sums in separate kernel
285285
_layer_norm_bwd_dwdb[grid](
286-
_dw, _db, dw, db, GROUP_SIZE_M, N, #
286+
_dw, _db, dw, db, min(GROUP_SIZE_M, M), N, #
287287
BLOCK_SIZE_M=32, #
288288
BLOCK_SIZE_N=128, num_ctas=1)
289289
return dx, None, dw, db, None

0 commit comments

Comments
 (0)