Trace for demo_verif
  1. Preprocessing contracts
  2. Function.inline [ f; cCall "forward" ];
  3. Loop.hoist [ nbMulti; cFunDef "generate_prompt_proc"; cVarDefs [ "embedding"; "mha_norm"; "mha_q" ] ];
  4. Loop.fission [ f; cForBody "i"; cFor "l"; tBefore ];
  5. Loop.reorder_at ~order:[ "l"; "i" ] [ f; cForBody "l"; dSeqNth 0 ];

The four transformations are sketched below on toy loops; the recorded before/after diff follows them.
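
Step 2 (Function.inline) replaces the call to forward inside the i loop of generate_prompt_proc with the callee's body, so the later loop transformations can act on the inlined statements. A minimal sketch of call inlining on a toy function; the names body/caller_before/caller_after are illustrative, not the real signatures:

    // Before: the loop body is hidden behind a call.
    static void body(int i, float* a) { a[i] = 2.f * (float)i; }
    void caller_before(int n, float* a) {
      for (int i = 0; i < n; i++) body(i, a);
    }

    // After Function.inline: the callee's statements appear in the loop.
    void caller_after(int n, float* a) {
      for (int i = 0; i < n; i++) a[i] = 2.f * (float)i;
    }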
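
Step 3 (Loop.hoist) moves the embedding, mha_norm and mha_q allocations out of the i loop; each buffer gains a leading sequence_len dimension so every iteration keeps a disjoint slice (visible in the diff as MSIZE1(embedding_dim) becoming MSIZE2(sequence_len, embedding_dim)). A minimal sketch with toy sizes n and m in place of the real dimensions:

    #include <stdlib.h>

    // Before: one allocation per iteration of the hoisted loop.
    void hoist_before(int n, int m) {
      for (int i = 0; i < n; i++) {
        float* buf = (float*)malloc(m * sizeof(float));
        for (int j = 0; j < m; j++) buf[j] = (float)(i + j);
        free(buf);
      }
    }

    // After Loop.hoist: one allocation, widened by the loop dimension n;
    // iteration i uses its private row buf[i * m .. i * m + m).
    void hoist_after(int n, int m) {
      float* buf = (float*)malloc(n * m * sizeof(float));
      for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++) buf[i * m + j] = (float)(i + j);
      free(buf);
    }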
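
Step 4 (Loop.fission with tBefore) splits the i loop immediately before the l loop, leaving one i loop that prepares the per-token buffers and another that contains the layer loop. A minimal sketch on a toy nest; a and b stand in for the real buffers:

    // Before: setup and layer phase share one i loop.
    void fission_before(int n, int layers, float* a, float* b) {
      for (int i = 0; i < n; i++) {
        a[i] = 1.f;                                     // setup phase
        for (int l = 0; l < layers; l++) b[i] += a[i];  // layer phase
      }
    }

    // After Loop.fission at tBefore the l loop: two i loops, one per phase.
    void fission_after(int n, int layers, float* a, float* b) {
      for (int i = 0; i < n; i++) a[i] = 1.f;
      for (int i = 0; i < n; i++)
        for (int l = 0; l < layers; l++) b[i] += a[i];
    }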
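
Step 5 (Loop.reorder_at ~order:[ "l"; "i" ]) interchanges the nest produced by the fission so that the layer loop l runs outermost, which is the loop order seen in the final generate_prompt_proc. A minimal sketch; the interchange is valid here because each b[i] accumulates the same additions in the same per-i order:

    // Before: token loop outside, layer loop inside.
    void reorder_before(int n, int layers, float* b, const float* a) {
      for (int i = 0; i < n; i++)
        for (int l = 0; l < layers; l++) b[i] += a[l];
    }

    // After Loop.reorder_at ~order:["l"; "i"]: layer loop outside.
    void reorder_after(int n, int layers, float* b, const float* a) {
      for (int l = 0; l < layers; l++)
        for (int i = 0; i < n; i++) b[i] += a[l];
    }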

    
tmp/{beforee69392.cpp → afterbdc661.cpp} RENAMED
@@ -21,20 +21,21 @@ float _sinf(float x) {
 
 float _sqrtf(float x) {
   __pure();
   __admitted();
   return sqrtf(x);
 }
 
 void rope(int col_count, float* x, int pos) {
   __modifies("x ~> Matrix1(col_count)");
   for (int j = 0; j < col_count; j += 2) {
+    __strict();
     __smodifies("x ~> Matrix1(col_count)");
     __ghost(assume, "P := is_subrange(j..(j + 1), 0..col_count)");
     const __ghost_fn focus_subrange =
         __ghost_begin(group_focus_subrange, "sub_range := j..(j + 1)");
     float freq = 1.f / _powf(500000.f, (float)j / (float)col_count);
     float val = (float)pos * 2.f;
     float fcr = _cosf(val);
     float fci = _sinf(val);
     __ghost([&]() {
       __consumes("for i1 in j..(j + 1) -> &x[MINDEX1(col_count, i1)] ~> Cell");
@@ -64,68 +65,88 @@ float _expf(float x) {
 
 void softmax(int col_count, int col_stride, float* x) {
   __modifies("x ~> Matrix1(col_count)");
   __ghost(assume, "P := in_range(0, 0..col_count)");
   const __ghost_fn max = __ghost_begin(ro_matrix1_focus, "matrix := x, i := 0");
   float max_val = x[MINDEX1(col_count, 0)];
   __ghost_end(max);
   const __ghost_fn focus_subrange =
       __ghost_begin(group_focus_subrange, "sub_range := 1..col_count");
   for (int j = 1; j < col_count; j++) {
+    __strict();
+    __smodifies(
+        "Wand(for i1 in 1..col_count -> &x[MINDEX1(col_count, i1)] ~> Cell, x "
+        "~> Matrix1(col_count))");
+    __smodifies("&max_val ~> Cell");
     __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
     if (x[MINDEX1(col_count, j)] > max_val) {
       max_val = x[MINDEX1(col_count, j)];
     }
   }
   __ghost_end(focus_subrange);
   float sum = 0.f;
   for (int j = 0; j < col_count; j++) {
+    __strict();
+    __smodifies("&sum ~> Cell");
+    __sreads("&max_val ~> Cell");
     __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
     x[MINDEX1(col_count, j)] = _expf(x[MINDEX1(col_count, j)] - max_val);
     sum += x[MINDEX1(col_count, j)];
   }
   for (int j = 0; j < col_count; j++) {
+    __strict();
+    __sreads("&sum ~> Cell");
     __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
     x[MINDEX1(col_count, j)] /= sum;
   }
 }
 
 void rmsnorm(int col_count, float* y, float* x, float* w, float epsilon) {
   __writes("y ~> Matrix1(col_count)");
   __reads("x ~> Matrix1(col_count)");
   __reads("w ~> Matrix1(col_count)");
   float ss = 0.f;
   for (int j = 0; j < col_count; j++) {
+    __strict();
     __smodifies("&ss ~> Cell");
     __xreads("&x[MINDEX1(col_count, j)] ~> Cell");
     ss += x[MINDEX1(col_count, j)] * x[MINDEX1(col_count, j)];
   }
   ss /= (float)col_count;
   ss += epsilon;
   ss = 1.f / _sqrtf(ss);
   for (int j = 0; j < col_count; j++) {
+    __strict();
+    __sreads("&ss ~> Cell");
     __xwrites("&y[MINDEX1(col_count, j)] ~> Cell");
     __xreads("&x[MINDEX1(col_count, j)] ~> Cell");
     __xreads("&w[MINDEX1(col_count, j)] ~> Cell");
     y[MINDEX1(col_count, j)] =
         w[MINDEX1(col_count, j)] * (ss * x[MINDEX1(col_count, j)]);
   }
 }
 
 void matvec(int col_count, int red_count, float* x, float* y, float* w) {
   __writes("x ~> Matrix1(col_count)");
   __reads("y ~> Matrix1(red_count)");
   __reads("w ~> Matrix2(col_count, red_count)");
   for (int j = 0; j < col_count; j++) {
+    __strict();
+    __sreads("y ~> Matrix1(red_count)");
+    __sreads("w ~> Matrix2(col_count, red_count)");
     __xwrites("&x[MINDEX1(col_count, j)] ~> Cell");
     x[MINDEX1(col_count, j)] = 0.f;
     for (int k = 0; k < red_count; k++) {
+      __strict();
+      __smodifies("&x[MINDEX1(col_count, j)] ~> Cell");
+      __sreads("y ~> Matrix1(red_count)");
+      __sreads("w ~> Matrix2(col_count, red_count)");
       const __ghost_fn focusy =
           __ghost_begin(ro_matrix1_focus, "matrix := y, i := k");
       const __ghost_fn focusw =
          __ghost_begin(ro_matrix2_focus, "matrix := w, i := j, j := k");
       x[MINDEX1(col_count, j)] +=
           y[MINDEX1(red_count, k)] * w[MINDEX2(col_count, red_count, j, k)];
       __ghost_end(focusy);
       __ghost_end(focusw);
     }
   }
@@ -145,46 +166,54 @@ void forward(int token, int vocabulary_len, int context_len, int layer_count,
   __reads(
       "mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
       "embedding_dim)");
   float* const embedding =
       (float*)malloc(MSIZE1(embedding_dim) * sizeof(float));
   float* const mha_norm = (float*)malloc(MSIZE1(embedding_dim) * sizeof(float));
   float* const mha_q =
       (float*)malloc(MSIZE2(q_head_count, head_dim) * sizeof(float));
   __ghost(assume, "P := in_range(token, 0..vocabulary_len)");
   for (int e = 0; e < embedding_dim; e++) {
+    __strict();
+    __sreads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
     __xwrites("&embedding[MINDEX1(embedding_dim, e)] ~> Cell");
     const __ghost_fn focus_embedding_weight = __ghost_begin(
         ro_matrix2_focus, "matrix := embedding_weight, i := token, j := e");
     embedding[MINDEX1(embedding_dim, e)] =
         embedding_weight[MINDEX2(vocabulary_len, embedding_dim, token, e)];
     __ghost_end(focus_embedding_weight);
   }
   __ghost([&]() {
     __consumes("embedding ~> Matrix1(embedding_dim)");
     __consumes("mha_norm ~> UninitMatrix1(embedding_dim)");
     __produces("&embedding[MINDEX0()] ~> Matrix1(embedding_dim)");
     __produces("&mha_norm[MINDEX0()] ~> Matrix1(embedding_dim)");
     __admitted();
   });
   for (int l = 0; l < layer_count; l++) {
+    __strict();
+    __smodifies("&mha_norm[MINDEX0()] ~> Matrix1(embedding_dim)");
+    __smodifies("mha_q ~> UninitMatrix2(q_head_count, head_dim)");
+    __sreads("&embedding[MINDEX0()] ~> Matrix1(embedding_dim)");
     __xreads(
         "for i1 in 0..embedding_dim -> &mha_norm_weight[MINDEX2(layer_count, "
         "embedding_dim, l, i1)] ~> Cell");
     __xreads(
         "for q in 0..q_head_count -> for h in 0..head_dim -> for e in "
         "0..embedding_dim -> &mha_q_weight[MINDEX4(layer_count, q_head_count, "
         "head_dim, embedding_dim, l, q, h, e)] ~> Cell");
     rmsnorm(embedding_dim, &mha_norm[MINDEX0()], &embedding[MINDEX0()],
             &mha_norm_weight[MINDEX2(layer_count, embedding_dim, l, 0)],
             epsilon);
     for (int q = 0; q < q_head_count; q++) {
+      __strict();
+      __sreads("&mha_norm[MINDEX0()] ~> Matrix1(embedding_dim)");
       __xmodifies(
           "for i1 in 0..head_dim -> &mha_q[MINDEX2(q_head_count, head_dim, q, "
           "i1)] ~> UninitCell");
       __xreads(
           "for h in 0..head_dim -> for e in 0..embedding_dim -> "
           "&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim, "
           "embedding_dim, l, q, h, e)] ~> Cell");
       matvec(head_dim, embedding_dim,
              &mha_q[MINDEX2(q_head_count, head_dim, q, 0)],
              &mha_norm[MINDEX0()],
@@ -215,23 +244,162 @@ void generate_prompt_proc(int vocabulary_len, int context_len, int layer_count,
                           float* ffn_fc_weight, float* ffn_up_weight,
                           float* ffn_out_weight, float* out_norm_weight,
                           float* out_weight, float* k_cache, float* v_cache,
                           float* logits, int* sequence, int sequence_len) {
   __reads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
   __reads("mha_norm_weight ~> Matrix2(layer_count, embedding_dim)");
   __reads("sequence ~> Matrix1(sequence_len)");
   __reads(
       "mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
       "embedding_dim)");
+  float* const embedding =
+      (float*)malloc(MSIZE2(sequence_len, embedding_dim) * sizeof(float));
+  float* const mha_norm =
+      (float*)malloc(MSIZE2(sequence_len, embedding_dim) * sizeof(float));
+  float* const mha_q = (float*)malloc(
+      MSIZE3(sequence_len, q_head_count, head_dim) * sizeof(float));
   for (int i = 0; i < sequence_len; i++) {
+    __strict();
+    __sreads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
+    __xconsumes(
+        "for _v2 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
+        "embedding_dim, i, _v2)] ~> UninitCell");
+    __xconsumes(
+        "for _v1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
+        "embedding_dim, i, _v1)] ~> UninitCell");
+    __xproduces(
+        "&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+        "Matrix1(embedding_dim)");
+    __xproduces(
+        "&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+        "Matrix1(embedding_dim)");
     __xreads("&sequence[MINDEX1(sequence_len, i)] ~> Cell");
     const int logits_count = 1;
     const int token = sequence[MINDEX1(sequence_len, i)];
-    forward(token, vocabulary_len, context_len, layer_count, q_head_count,
-            kv_head_count, q_head_per_kv_head_count, embedding_dim, head_dim,
-            q_dim, kv_dim, hidden_dim, epsilon, embedding_weight,
-            mha_norm_weight, mha_q_weight, mha_k_weight, mha_v_weight,
-            mha_out_weight, ffn_norm_weight, ffn_fc_weight, ffn_up_weight,
-            ffn_out_weight, out_norm_weight, out_weight, k_cache, v_cache,
-            logits, i, logits_count);
-  }
+    __ghost(assume, "P := in_range(token, 0..vocabulary_len)", "#_8 <- H");
+    for (int e = 0; e < embedding_dim; e++) {
+      __strict();
+      __sreads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
+      __xwrites(
+          "&embedding[MINDEX2(sequence_len, embedding_dim, i, e)] ~> Cell");
+      const __ghost_fn focus_embedding_weight = __ghost_begin(
+          ro_matrix2_focus, "matrix := embedding_weight, i := token, j := e");
+      embedding[MINDEX2(sequence_len, embedding_dim, i, e)] =
+          embedding_weight[MINDEX2(vocabulary_len, embedding_dim, token, e)];
+      __ghost_end(focus_embedding_weight);
+    }
+    __ghost([&]() {
+      __consumes(
+          "for i1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
+          "embedding_dim, i, i1)] ~> Cell");
+      __consumes(
+          "for i1 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
+          "embedding_dim, i, i1)] ~> UninitCell");
+      __produces(
+          "&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+          "Matrix1(embedding_dim)");
+      __produces(
+          "&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+          "Matrix1(embedding_dim)");
+      __admitted();
+    });
+  }
+  for (int l = 0; l < layer_count; l++) {
+    __strict();
+    __smodifies(
+        "for i in 0..sequence_len -> &mha_norm[MINDEX2(sequence_len, "
+        "embedding_dim, i, 0)] ~> Matrix1(embedding_dim)");
+    __smodifies("mha_q ~> UninitMatrix3(sequence_len, q_head_count, head_dim)");
+    __sreads(
+        "for i in 0..sequence_len -> &embedding[MINDEX2(sequence_len, "
+        "embedding_dim, i, 0)] ~> Matrix1(embedding_dim)");
+    __sreads("mha_norm_weight ~> Matrix2(layer_count, embedding_dim)");
+    __sreads(
+        "mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
+        "embedding_dim)");
+    for (int i = 0; i < sequence_len; i++) {
+      __strict();
+      __sreads("mha_norm_weight ~> Matrix2(layer_count, embedding_dim)");
+      __sreads(
+          "mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
+          "embedding_dim)");
+      __xmodifies(
+          "&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+          "Matrix1(embedding_dim)");
+      __xmodifies(
+          "for i1 in 0..q_head_count -> for i2 in 0..head_dim -> "
+          "&mha_q[MINDEX3(sequence_len, q_head_count, head_dim, i, i1, i2)] ~> "
+          "UninitCell");
+      __xreads(
+          "&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+          "Matrix1(embedding_dim)");
+      const __ghost_fn __ghost_pair_2 = __ghost_begin(
+          ro_group_focus,
+          "i := l, items := fun (l: int) -> for i1 in 0..embedding_dim -> "
+          "&mha_norm_weight[MINDEX2(layer_count, embedding_dim, l, i1)] ~> "
+          "Cell");
+      const __ghost_fn __ghost_pair_1 = __ghost_begin(
+          ro_group_focus,
+          "i := l, items := fun (l: int) -> for q in 0..q_head_count -> for h "
+          "in 0..head_dim -> for e in 0..embedding_dim -> "
+          "&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim, "
+          "embedding_dim, l, q, h, e)] ~> Cell");
+      rmsnorm(
+          embedding_dim, &mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)],
+          &embedding[MINDEX2(sequence_len, embedding_dim, i, 0)],
+          &mha_norm_weight[MINDEX2(layer_count, embedding_dim, l, 0)], epsilon);
+      for (int q = 0; q < q_head_count; q++) {
+        __strict();
+        __sreads(
+            "&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+            "Matrix1(embedding_dim)");
+        __xmodifies(
+            "for i1 in 0..head_dim -> &mha_q[MINDEX3(sequence_len, "
+            "q_head_count, head_dim, i, q, i1)] ~> UninitCell");
+        __xreads(
+            "for h in 0..head_dim -> for e in 0..embedding_dim -> "
+            "&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim, "
+            "embedding_dim, l, q, h, e)] ~> Cell");
+        matvec(head_dim, embedding_dim,
+               &mha_q[MINDEX3(sequence_len, q_head_count, head_dim, i, q, 0)],
+               &mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)],
+               &mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim,
+                                     embedding_dim, l, q, 0, 0)]);
+      }
+      __ghost_end(__ghost_pair_1);
+      __ghost_end(__ghost_pair_2);
+    }
+  }
+  for (int i = 0; i < sequence_len; i++) {
+    __strict();
+    __xconsumes(
+        "&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+        "Matrix1(embedding_dim)");
+    __xconsumes(
+        "&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+        "Matrix1(embedding_dim)");
+    __xproduces(
+        "for _v2 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
+        "embedding_dim, i, _v2)] ~> UninitCell");
+    __xproduces(
+        "for _v1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
+        "embedding_dim, i, _v1)] ~> UninitCell");
+    __ghost([&]() {
+      __consumes(
+          "&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+          "Matrix1(embedding_dim)");
+      __consumes(
+          "&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
+          "Matrix1(embedding_dim)");
+      __produces(
+          "for i1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
+          "embedding_dim, i, i1)] ~> Cell");
+      __produces(
+          "for i1 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
+          "embedding_dim, i, i1)] ~> UninitCell");
+      __admitted();
+    });
+  }
+  free(mha_q);
+  free(mha_norm);
+  free(embedding);
 }