@@ -60,73 +60,78 @@ __global__ void flash_attn_v1_kernel(const float *Q,
6060
6161 // Define SRAM for Q,K,V,S
6262 extern __shared__ float sram[];
63- int tile_size = Bc * d; // size of Qi, Kj, Vj
63+ const int KV_TILE_SIZE = Bc * d; // size of Kj, Vj
64+ const int Q_TILE_SIZE = Br * d; // size of Qi
65+ // const int S_TILE_SIZE = Br * Bc; // size of Sij = softmax(Qi * Kj^T * softmax_scale)
6466 float *Qi = sram;
65- float *Kj = &sram[tile_size ];
66- float *Vj = &sram[tile_size * 2 ];
67- float *S = &sram[tile_size * 3 ];
67+ float *Kj = &sram[Q_TILE_SIZE ];
68+ float *Vj = &sram[Q_TILE_SIZE + KV_TILE_SIZE ];
69+ float *S = &sram[Q_TILE_SIZE + KV_TILE_SIZE * 2 ];
6870
6971 // Outer loop: iterate over the Tc key/value tiles (Kj, Vj)
7072 for (int j = 0 ; j < Tc; j++)
7173 {
7274 // Load Kj, Vj from HBM to SRAM
7375 for (int x = 0 ; x < d; x++)
7476 {
75- Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
76- Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
77+ Kj[(tx * d) + x] = K[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
78+ Vj[(tx * d) + x] = V[qkv_offset + (KV_TILE_SIZE * j) + (tx * d) + x];
7779 }
7880 __syncthreads ();
7981
8082 for (int i = 0 ; i < Tr; i++)
8183 {
82- // Load Qi to SRAM, l and m to registers
83- for (int x = 0 ; x < d; x++)
84+ if (tx < Br)
8485 {
85- Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
86- }
87- float row_m_prev = m[lm_offset + (Br * i) + tx];
88- float row_l_prev = l[lm_offset + (Br * i) + tx];
89-
90- // S = QK^T, row_m = rowmax(S)
91- float row_m = -INFINITY;
92- for (int y = 0 ; y < Bc; y++)
93- {
94- float sum = 0 ;
86+ // Load Qi to SRAM, l and m to registers
9587 for (int x = 0 ; x < d; x++)
9688 {
97- sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
89+ Qi[(tx * d) + x] = Q[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x];
9890 }
99- sum *= softmax_scale;
100- S[(Bc * tx) + y] = sum;
101-
102- if (sum > row_m)
103- row_m = sum;
104- }
91+ float row_m_prev = m[lm_offset + (Br * i) + tx];
92+ float row_l_prev = l[lm_offset + (Br * i) + tx];
10593
106- // P = exp(S - row_m), row_l = rowsum(P)
107- float row_l = 0 ;
108- for (int y = 0 ; y < Bc; y++)
109- {
110- S[(Bc * tx) + y] = __expf (S[(Bc * tx) + y] - row_m);
111- row_l += S[(Bc * tx) + y];
112- }
94+ // S = QK^T, row_m = rowmax(S)
95+ float row_m = -INFINITY;
96+ for (int y = 0 ; y < Bc; y++)
97+ {
98+ float sum = 0 ;
99+ for (int x = 0 ; x < d; x++)
100+ {
101+ sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
102+ }
103+ sum *= softmax_scale;
104+ S[(Bc * tx) + y] = sum;
113105
114- // Compute new m and l
115- float row_m_new = max (row_m_prev, row_m) ;
116- float row_l_new = ( __expf (row_m_prev - row_m_new) * row_l_prev) + ( __expf (row_m - row_m_new) * row_l);
106+ if (sum > row_m)
107+ row_m = sum ;
108+ }
117109
118- // Write O, l, m to HBM
119- for (int x = 0 ; x < d; x++)
120- {
121- float pv = 0 ; // Pij * Vj
110+ // P = exp(S - row_m), row_l = rowsum(P)
111+ float row_l = 0 ;
122112 for (int y = 0 ; y < Bc; y++)
123113 {
124- pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
114+ S[(Bc * tx) + y] = __expf (S[(Bc * tx) + y] - row_m);
115+ row_l += S[(Bc * tx) + y];
116+ }
117+
118+ // Compute new m and l
119+ float row_m_new = max (row_m_prev, row_m);
120+ float row_l_new = (__expf (row_m_prev - row_m_new) * row_l_prev) + (__expf (row_m - row_m_new) * row_l);
121+
122+ // Write O, l, m to HBM
123+ for (int x = 0 ; x < d; x++)
124+ {
125+ float pv = 0 ; // Pij * Vj
126+ for (int y = 0 ; y < Bc; y++)
127+ {
128+ pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
129+ }
130+ O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf (row_m_prev - row_m_new) * O[qkv_offset + (Q_TILE_SIZE * i) + (tx * d) + x]) + (__expf (row_m - row_m_new) * pv));
125131 }
126- O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) * ((row_l_prev * __expf (row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) + (__expf (row_m - row_m_new) * pv));
132+ m[lm_offset + (Br * i) + tx] = row_m_new;
133+ l[lm_offset + (Br * i) + tx] = row_l_new;
127134 }
128- m[lm_offset + (Br * i) + tx] = row_m_new;
129- l[lm_offset + (Br * i) + tx] = row_l_new;
130135 }
131136 __syncthreads ();
132137 }
@@ -234,7 +239,8 @@ int main()
234239
235240 // Split the KV seq_len into Tc tiles of Bc rows and the Q seq_len into Tr tiles of Br rows
236241 const int Bc = 32 ;
237- const int Br = 32 ;
242+ // const int Br = 32;
243+ const int Br = 16 ;
238244 const int Tc = ceil ((float )N / Bc);
239245 const int Tr = ceil ((float )N / Br);
240246
@@ -305,7 +311,7 @@ int main()
305311
306312 if (max_diff < 0.0001 )
307313 {
308- printf (" Results are correct! " );
314+ printf (" Results are correct! \n " );
309315 }
310316 else
311317 {
0 commit comments