@@ -289,12 +289,12 @@ def forward(
         values = self.cache_v[:bsz, : start_pos + seqlen]

         # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        keys = keys.transpose(1, 2)
-        values = values.transpose(1, 2)
+        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         if mask is not None:
             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
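For context, a minimal standalone sketch (with made-up sizes, not the model's real dimensions) of why the score matrix spans `cache_len + seqlen` columns when the KV cache is in use: the queries cover only the new tokens, while the (repeated) keys cover the cached prefix plus the new tokens.

```python
import math
import torch

# Illustrative sizes only; not the model's real dimensions.
bs, n_heads, head_dim = 1, 4, 16
cache_len, seqlen = 5, 3

xq = torch.randn(bs, n_heads, seqlen, head_dim)                # new tokens only
keys = torch.randn(bs, n_heads, cache_len + seqlen, head_dim)  # cached prefix + new tokens

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
print(scores.shape)  # torch.Size([1, 4, 3, 8]) == (bs, n_heads, seqlen, cache_len + seqlen)
```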
@@ -474,9 +474,19 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
         mask = None
         if seqlen > 1:
             mask = torch.full(
-                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
+                (seqlen, seqlen), float("-inf"), device=tokens.device
             )
-            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+            mask = torch.triu(mask, diagonal=1)
+
+            # When performing key-value caching, we compute the attention scores
+            # only for the new sequence. Thus, the matrix of scores is of size
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
+            mask = torch.hstack([
+                torch.zeros((seqlen, start_pos), device=tokens.device),
+                mask
+            ]).type_as(h)

         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
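To see what the new mask looks like, here is a small sketch of the construction above with example values (`start_pos` is the number of cached tokens, i.e. `cache_len`):

```python
import torch

seqlen, start_pos = 3, 2  # example values: 3 new tokens after 2 cached ones

mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)        # causal mask over the new tokens
mask = torch.hstack([
    torch.zeros((seqlen, start_pos)),      # cached positions are always attendable
    mask,
])                                         # shape: (seqlen, cache_len + seqlen)
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])
```

Because the mask now has shape `(seqlen, cache_len + seqlen)`, it broadcasts directly against the scores of shape `(bs, n_local_heads, seqlen, cache_len + seqlen)`, so the leading `(1, 1, ...)` dimensions of the old mask are no longer needed, and entry `(i, j)` is masked exactly when `j > cache_len + i`.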