Transformation trace for `forward`:
  1. (Function.inline [ f; cCall "forward" ]);
  2. (Loop.tile ~bound:TileBoundMin (trm_int chunk_len) [ f; cFor "i" ]);
  3. Loop.hoist [nbMulti;cFunDef "generate_prompt_proc"; cVarDefs [ "embedding"; "mha_norm"; "mha_q"; "mha_score"; "mha_att"; "mha_blend"; "mha_out"; "ffn_norm"; "ffn_up"; "ffn_fc"; "ffn_out" ]; ];
  4. Loop.fission [ f; cForBody "i"; tBetweenAll ];
  5. Loop.reorder_at ~order:[ "l"; "i" ] [ f; cForBody "l"; dSeqNth 0 ];
  6. Loop.fission [ f; cForBody "l"; cForBody "i"; tBetweenAll ];
  7. Loop.reorder_at ~order:[ "q"; "i" ] [ nbMulti; f; cForBody "q"; dSeqNth 0 ];
  8. Loop.reorder_at ~order:[ "h"; "i" ] [ nbMulti; f; cForBody "h"; dSeqNth 0 ];
  9. Matrix.reorder_dims ~order:[ 1; 0; 2 ] [ nbMulti; f; cVarDefs [ "mha_q"; "mha_score"; "mha_blend" ] ];
  10. Matrix.tile ~block_size:q_head_per_kv_head_count ~nb_blocks:kv_headcount ~index_dim:0 [ nbMulti; f; cVarDefs [ "mha_q"; "mha_score"; "mha_blend" ] ];
  11. Loop_basic.grid_enumerate [ ("h2", kv_headcount); ("q2", q_head_per_kv_head_count) ] [ nbMulti; f; cFor "q" ];
  12. Variable.inline [ nbMulti; f; cVarDef "q" ];
  13. Rewrite.equiv_at "int i; int j ; int k; ==> (i*j +k) /j == i" ~indepth:true [ f ];
  14. Rewrite.equiv_at "int i; int j ; int k; ==> (i*j +k) %j == k" ~indepth:true [ f ];
  15. Function.inline [ nbMulti; f; cCall "matvec" ];
  16. Matrix.simpl_access_of_access ~indepth:true [ f ];
  17. Matrix.simpl_index_add [ nbMulti; f; cCellAccess ~base:[ cVar ~substr:true "" ] (); cBinop ~lhs:[ cCall ~regexp:true "MINDEX." ] Binop_add ];
  18. Rewrite.equiv_at "int j; ==> 0 + j == j" [ nbMulti; f ] ~indepth:true;
  19. Function.uninline ~f:[ cFunDef "matmul" ] [ occIndices [ 0; 3; 4; 5 ]; f; cFor "i" ~body:[ cSeq ~instrs_pred:(target_list_one_st [ cFor "j" ]) () ] ])

    
tmp/{before6fc5cf.cpp → after016d03.cpp} RENAMED
@@ -184,20 +184,238 @@ void generate_prompt_proc(int vocabulary_len, int context_len, int layer_count,
184
  int q_head_per_kv_head_count, int embedding_dim,
185
  int head_dim, int q_dim, int kv_dim, int hidden_dim,
186
  float epsilon, float* embedding_weight,
187
  float* mha_norm_weight, float* mha_q_weight,
188
  float* mha_k_weight, float* mha_v_weight,
189
  float* mha_out_weight, float* ffn_norm_weight,
190
  float* ffn_fc_weight, float* ffn_up_weight,
191
  float* ffn_out_weight, float* out_norm_weight,
192
  float* out_weight, float* k_cache, float* v_cache,
193
  float* logits, int* sequence, int sequence_len) {
194
- for (int i = 0; i < sequence_len; i++) {
 
 
 
 
 
 
 
195
- forward(sequence[i], vocabulary_len, context_len, layer_count, q_head_count,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
- kv_head_count, q_head_per_kv_head_count, embedding_dim, head_dim,
 
 
 
 
 
197
- q_dim, kv_dim, hidden_dim, epsilon, embedding_weight,
198
- mha_norm_weight, mha_q_weight, mha_k_weight, mha_v_weight,
 
 
 
 
 
 
 
 
 
 
199
- mha_out_weight, ffn_norm_weight, ffn_fc_weight, ffn_up_weight,
 
 
 
 
 
 
 
 
 
 
 
 
200
- ffn_out_weight, out_norm_weight, out_weight, k_cache, v_cache,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
- logits, i, i == sequence_len - 1 ? 1 : 0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  }
203
  }
184
  int q_head_per_kv_head_count, int embedding_dim,
185
  int head_dim, int q_dim, int kv_dim, int hidden_dim,
186
  float epsilon, float* embedding_weight,
187
  float* mha_norm_weight, float* mha_q_weight,
188
  float* mha_k_weight, float* mha_v_weight,
189
  float* mha_out_weight, float* ffn_norm_weight,
190
  float* ffn_fc_weight, float* ffn_up_weight,
191
  float* ffn_out_weight, float* out_norm_weight,
192
  float* out_weight, float* k_cache, float* v_cache,
193
  float* logits, int* sequence, int sequence_len) {
194
+ for (int bi = 0; bi < sequence_len; bi += 512) {
195
+ float* const embedding =
196
+ (float*)malloc(MSIZE2(min(sequence_len, bi + 512) - bi, embedding_dim) *
197
+ sizeof(float));
198
+ float* const mha_norm =
199
+ (float*)malloc(MSIZE2(min(sequence_len, bi + 512) - bi, embedding_dim) *
200
+ sizeof(float));
201
+ float* const mha_q =
202
+ (float*)malloc(MSIZE4(kv_head_count, q_head_per_kv_head_count,
203
+ min(sequence_len, bi + 512) - bi, head_dim) *
204
+ sizeof(float));
205
+ __ghost([&]() {
206
+ __consumes(
207
+ "for i1 in 0..q_head_count -> for i2 in 0..(min(sequence_len, bi + "
208
+ "512) - bi) -> for i3 in 0..head_dim -> &mha_q[i1 / "
209
+ "q_head_per_kv_head_count][i1 % q_head_per_kv_head_count][i2][i3] ~> "
210
+ "UninitCell");
211
+ __produces(
212
+ "for i1 in 0..(min(sequence_len, bi + 512) - bi) -> for i2 in "
213
+ "0..q_head_count -> for i3 in 0..head_dim -> &mha_q[i1 / "
214
+ "q_head_per_kv_head_count][i1 % q_head_per_kv_head_count][i2][i3] ~> "
215
+ "UninitCell");
216
+ __admitted();
217
+ __with("justif := reorder_groups");
218
+ });
219
+ float* const mha_score =
220
+ (float*)malloc(MSIZE4(kv_head_count, q_head_per_kv_head_count,
221
+ min(sequence_len, bi + 512) - bi, context_len) *
222
+ sizeof(float));
223
+ __ghost([&]() {
224
+ __consumes(
225
+ "for i1 in 0..q_head_count -> for i2 in 0..(min(sequence_len, bi + "
226
+ "512) - bi) -> for i3 in 0..context_len -> &mha_score[i1 / "
227
+ "q_head_per_kv_head_count][i1 % q_head_per_kv_head_count][i2][i3] ~> "
228
+ "UninitCell");
229
+ __produces(
230
+ "for i1 in 0..(min(sequence_len, bi + 512) - bi) -> for i2 in "
231
+ "0..q_head_count -> for i3 in 0..context_len -> &mha_score[i1 / "
232
+ "q_head_per_kv_head_count][i1 % q_head_per_kv_head_count][i2][i3] ~> "
233
+ "UninitCell");
234
+ __admitted();
235
+ __with("justif := reorder_groups");
236
+ });
237
+ float* const mha_blend =
238
+ (float*)malloc(MSIZE4(kv_head_count, q_head_per_kv_head_count,
239
+ min(sequence_len, bi + 512) - bi, head_dim) *
240
+ sizeof(float));
241
+ __ghost([&]() {
242
+ __consumes(
243
+ "for i1 in 0..q_head_count -> for i2 in 0..(min(sequence_len, bi + "
244
+ "512) - bi) -> for i3 in 0..head_dim -> &mha_blend[i1 / "
245
+ "q_head_per_kv_head_count][i1 % q_head_per_kv_head_count][i2][i3] ~> "
246
+ "UninitCell");
247
+ __produces(
248
+ "for i1 in 0..(min(sequence_len, bi + 512) - bi) -> for i2 in "
249
+ "0..q_head_count -> for i3 in 0..head_dim -> &mha_blend[i1 / "
250
+ "q_head_per_kv_head_count][i1 % q_head_per_kv_head_count][i2][i3] ~> "
251
+ "UninitCell");
252
+ __admitted();
253
+ __with("justif := reorder_groups");
254
+ });
255
+ float* const mha_att =
256
+ (float*)malloc(MSIZE2(min(sequence_len, bi + 512) - bi, embedding_dim) *
257
+ sizeof(float));
258
+ float* const mha_out =
259
+ (float*)malloc(MSIZE2(min(sequence_len, bi + 512) - bi, embedding_dim) *
260
+ sizeof(float));
261
+ float* const ffn_norm =
262
+ (float*)malloc(MSIZE2(min(sequence_len, bi + 512) - bi, embedding_dim) *
263
+ sizeof(float));
264
+ float* const ffn_fc = (float*)malloc(
265
+ MSIZE2(min(sequence_len, bi + 512) - bi, hidden_dim) * sizeof(float));
266
+ float* const ffn_up = (float*)malloc(
267
+ MSIZE2(min(sequence_len, bi + 512) - bi, hidden_dim) * sizeof(float));
268
+ float* const ffn_out =
269
+ (float*)malloc(MSIZE2(min(sequence_len, bi + 512) - bi, embedding_dim) *
270
+ sizeof(float));
271
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
272
+ __ghost(assume, "P := in_range(i + bi, bi..min(sequence_len, bi + 512))");
273
+ }
274
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
275
+ for (int e = 0; e < embedding_dim; e++) {
276
+ embedding[i][e] = embedding_weight[sequence[i + bi]][e];
277
+ }
278
+ }
279
+ for (int l = 0; l < layer_count; l++) {
280
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
281
+ rmsnorm(embedding_dim, &mha_norm[i][0], &embedding[i][0],
282
+ &mha_norm_weight[l][0], epsilon);
283
+ }
284
+ for (int h2 = 0; h2 < kv_head_count; h2++) {
285
+ for (int q2 = 0; q2 < q_head_per_kv_head_count; q2++) {
286
+ matmul(min(sequence_len, bi + 512) - bi, head_dim, embedding_dim,
287
+ &mha_q[h2][q2][0][0], mha_norm,
288
+ &mha_q_weight[l][h2 * q_head_per_kv_head_count + q2][0][0]);
289
+ }
290
+ }
291
+ for (int h = 0; h < kv_head_count; h++) {
292
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
293
+ for (int j = 0; j < head_dim; j++) {
294
+ k_cache[l][h][i + bi][j] = 0.f;
295
+ for (int k = 0; k < embedding_dim; k++) {
296
+ k_cache[l][h][i + bi][j] +=
297
+ mha_norm[i][k] * mha_k_weight[l][h][j][k];
298
+ }
299
+ }
300
+ }
301
+ }
302
+ for (int h = 0; h < kv_head_count; h++) {
303
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
304
+ for (int j = 0; j < head_dim; j++) {
305
+ v_cache[l][h][i + bi][j] = 0.f;
306
+ for (int k = 0; k < embedding_dim; k++) {
307
+ v_cache[l][h][i + bi][j] +=
308
+ mha_norm[i][k] * mha_v_weight[l][h][j][k];
309
+ }
310
+ }
311
+ }
312
+ }
313
+ for (int h2 = 0; h2 < kv_head_count; h2++) {
314
+ for (int q2 = 0; q2 < q_head_per_kv_head_count; q2++) {
315
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
316
+ rope(head_dim, &mha_q[h2][q2][i][0], i + bi);
317
+ }
318
+ }
319
+ }
320
+ for (int h = 0; h < kv_head_count; h++) {
321
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
322
+ rope(head_dim, &k_cache[l][h][i + bi][0], i + bi);
323
+ }
324
+ }
325
+ for (int h2 = 0; h2 < kv_head_count; h2++) {
326
+ for (int q2 = 0; q2 < q_head_per_kv_head_count; q2++) {
327
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
328
+ int h = h2;
329
+ for (int p = 0; p <= i + bi; p++) {
330
+ mha_score[h2][q2][i][p] = 0.f;
331
+ for (int e = 0; e < head_dim; e++) {
332
+ mha_score[h2][q2][i][p] +=
333
+ mha_q[h2][q2][i][e] * k_cache[l][h][p][e];
334
+ }
335
+ mha_score[h2][q2][i][p] /= sqrtf(head_dim);
336
+ }
337
+ softmax(i + bi + 1, context_len, &mha_score[h2][q2][i][0]);
338
+ for (int e = 0; e < head_dim; e++) {
339
+ mha_blend[h2][q2][i][e] = 0.f;
340
+ }
341
+ for (int p = 0; p <= i + bi; p++) {
342
+ for (int e = 0; e < head_dim; e++) {
343
+ mha_blend[h2][q2][i][e] +=
344
+ mha_score[h2][q2][i][p] * v_cache[l][h][p][e];
345
+ }
346
+ }
347
+ }
348
+ }
349
+ }
350
+ for (int h2 = 0; h2 < kv_head_count; h2++) {
351
+ for (int q2 = 0; q2 < q_head_per_kv_head_count; q2++) {
352
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
353
+ for (int e = 0; e < head_dim; e++) {
354
+ mha_att[i][(h2 * q_head_per_kv_head_count + q2) * head_dim + e] =
355
+ mha_blend[h2][q2][i][e];
356
+ }
357
+ }
358
+ }
359
+ }
360
+ matmul(min(sequence_len, bi + 512) - bi, embedding_dim, embedding_dim,
361
+ mha_out, mha_att, &mha_out_weight[l][0][0]);
362
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
363
+ for (int e = 0; e < embedding_dim; e++) {
364
+ embedding[i][e] += mha_out[i][e];
365
+ }
366
+ }
367
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
368
+ rmsnorm(embedding_dim, &ffn_norm[i][0], &embedding[i][0],
369
+ &ffn_norm_weight[l][0], epsilon);
370
+ }
371
+ matmul(min(sequence_len, bi + 512) - bi, hidden_dim, embedding_dim,
372
+ ffn_fc, ffn_norm, &ffn_fc_weight[l][0][0]);
373
+ matmul(min(sequence_len, bi + 512) - bi, hidden_dim, embedding_dim,
374
+ ffn_up, ffn_norm, &ffn_up_weight[l][0][0]);
375
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
376
+ for (int e = 0; e < hidden_dim; e++) {
377
+ ffn_fc[i][e] *= 1.f / (1.f + expf(-ffn_fc[i][e]));
378
+ ffn_fc[i][e] *= ffn_up[i][e];
379
+ }
380
+ }
381
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
382
+ for (int j = 0; j < embedding_dim; j++) {
383
+ ffn_out[i][j] = 0.f;
384
+ for (int k = 0; k < hidden_dim; k++) {
385
+ ffn_out[i][j] += ffn_fc[i][k] * ffn_out_weight[l][j][k];
386
+ }
387
+ }
388
+ }
389
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
390
+ for (int e = 0; e < embedding_dim; e++) {
391
+ embedding[i][e] += ffn_out[i][e];
392
+ }
393
+ }
394
+ }
395
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
396
+ rmsnorm(embedding_dim, &embedding[i][0], &embedding[i][0],
397
+ out_norm_weight, epsilon);
398
+ }
399
+ for (int i = 0; i < min(sequence_len, bi + 512) - bi; i++) {
400
+ if (i + bi == sequence_len - 1 ? 1 : 0) {
401
+ for (int j = 0; j < vocabulary_len; j++) {
402
+ logits[j] = 0.f;
403
+ for (int k = 0; k < embedding_dim; k++) {
404
+ logits[j] += embedding[i][k] * out_weight[j][k];
405
+ }
406
+ }
407
+ }
408
+ }
409
+ free(ffn_out);
410
+ free(ffn_up);
411
+ free(ffn_fc);
412
+ free(ffn_norm);
413
+ free(mha_out);
414
+ free(mha_att);
415
+ free(mha_blend);
416
+ free(mha_score);
417
+ free(mha_q);
418
+ free(mha_norm);
419
+ free(embedding);
420
  }
421
  }
 Show: