|
@@ -21,20 +21,21 @@ float _sinf(float x) {
|
|
| 21 |
|
| 22 |
// Thin wrapper around sqrtf, declared side-effect free for the verifier.
// __pure(): this function's contract claims no heap effects.
// __admitted(): the body itself is taken on trust rather than verified
// (sqrtf is an external libm call the checker cannot analyze).
float _sqrtf(float x) {
  __pure();
  __admitted();
  return sqrtf(x);
}
|
| 27 |
|
| 28 |
void rope(int col_count, float* x, int pos) {
|
| 29 |
__modifies("x ~> Matrix1(col_count)");
|
| 30 |
for (int j = 0; j < col_count; j += 2) {
|
|
|
|
| 31 |
__smodifies("x ~> Matrix1(col_count)");
|
| 32 |
__ghost(assume, "P := is_subrange(j..(j + 1), 0..col_count)");
|
| 33 |
const __ghost_fn focus_subrange =
|
| 34 |
__ghost_begin(group_focus_subrange, "sub_range := j..(j + 1)");
|
| 35 |
float freq = 1.f / _powf(500000.f, (float)j / (float)col_count);
|
| 36 |
float val = (float)pos * 2.f;
|
| 37 |
float fcr = _cosf(val);
|
| 38 |
float fci = _sinf(val);
|
| 39 |
__ghost([&]() {
|
| 40 |
__consumes("for i1 in j..(j + 1) -> &x[MINDEX1(col_count, i1)] ~> Cell");
|
|
@@ -64,68 +65,88 @@ float _expf(float x) {
|
|
| 64 |
|
| 65 |
// In-place, numerically stable softmax over x[0 .. col_count-1]:
//   1) find max_val = max(x)      (subtracted before exp to avoid overflow)
//   2) x[j] = exp(x[j] - max_val), accumulating sum
//   3) x[j] /= sum
// NOTE(review): col_stride is unused in this body — presumably kept for
// signature compatibility with a strided variant; confirm with callers.
// The __ghost/__modifies/__x* lines are verification annotations describing
// which cells each loop iteration may touch; they are not runtime code.
void softmax(int col_count, int col_stride, float* x) {
  __modifies("x ~> Matrix1(col_count)");
  // Justify reading x[0] before the loop (index 0 is in range).
  __ghost(assume, "P := in_range(0, 0..col_count)");
  const __ghost_fn max = __ghost_begin(ro_matrix1_focus, "matrix := x, i := 0");
  float max_val = x[MINDEX1(col_count, 0)];
  __ghost_end(max);
  // Focus the sub-range 1..col_count so the max loop owns one cell per step.
  const __ghost_fn focus_subrange =
      __ghost_begin(group_focus_subrange, "sub_range := 1..col_count");
  for (int j = 1; j < col_count; j++) {
    __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
    if (x[MINDEX1(col_count, j)] > max_val) {
      max_val = x[MINDEX1(col_count, j)];
    }
  }
  __ghost_end(focus_subrange);
  float sum = 0.f;
  // Exponentiate shifted values and accumulate the normalization constant.
  for (int j = 0; j < col_count; j++) {
    __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
    x[MINDEX1(col_count, j)] = _expf(x[MINDEX1(col_count, j)] - max_val);
    sum += x[MINDEX1(col_count, j)];
  }
  // Normalize so the entries sum to 1.
  for (int j = 0; j < col_count; j++) {
    __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
    x[MINDEX1(col_count, j)] /= sum;
  }
}
|
| 91 |
|
| 92 |
// RMS normalization: y[j] = w[j] * (x[j] / sqrt(mean(x^2) + epsilon)).
// Reads x and the per-channel weights w; writes every cell of y.
// epsilon guards against division by zero for all-zero input.
// Annotation lines (__writes/__reads/__s*/__x*) are verification resources:
// per-loop they state that ss is shared-modified while each iteration reads
// exactly its own x/w cell and writes exactly its own y cell.
void rmsnorm(int col_count, float* y, float* x, float* w, float epsilon) {
  __writes("y ~> Matrix1(col_count)");
  __reads("x ~> Matrix1(col_count)");
  __reads("w ~> Matrix1(col_count)");
  float ss = 0.f;
  // Sum of squares of x.
  for (int j = 0; j < col_count; j++) {
    __smodifies("&ss ~> Cell");
    __xreads("&x[MINDEX1(col_count, j)] ~> Cell");
    ss += x[MINDEX1(col_count, j)] * x[MINDEX1(col_count, j)];
  }
  // ss becomes the reciprocal root-mean-square scale factor.
  ss /= (float)col_count;
  ss += epsilon;
  ss = 1.f / _sqrtf(ss);
  // Scale each channel by the shared factor and its weight.
  for (int j = 0; j < col_count; j++) {
    __xwrites("&y[MINDEX1(col_count, j)] ~> Cell");
    __xreads("&x[MINDEX1(col_count, j)] ~> Cell");
    __xreads("&w[MINDEX1(col_count, j)] ~> Cell");
    y[MINDEX1(col_count, j)] =
        w[MINDEX1(col_count, j)] * (ss * x[MINDEX1(col_count, j)]);
  }
}
|
| 113 |
|
| 114 |
void matvec(int col_count, int red_count, float* x, float* y, float* w) {
|
| 115 |
__writes("x ~> Matrix1(col_count)");
|
| 116 |
__reads("y ~> Matrix1(red_count)");
|
| 117 |
__reads("w ~> Matrix2(col_count, red_count)");
|
| 118 |
for (int j = 0; j < col_count; j++) {
|
|
|
|
|
|
|
|
|
|
| 119 |
__xwrites("&x[MINDEX1(col_count, j)] ~> Cell");
|
| 120 |
x[MINDEX1(col_count, j)] = 0.f;
|
| 121 |
for (int k = 0; k < red_count; k++) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
const __ghost_fn focusy =
|
| 123 |
__ghost_begin(ro_matrix1_focus, "matrix := y, i := k");
|
| 124 |
const __ghost_fn focusw =
|
| 125 |
__ghost_begin(ro_matrix2_focus, "matrix := w, i := j, j := k");
|
| 126 |
x[MINDEX1(col_count, j)] +=
|
| 127 |
y[MINDEX1(red_count, k)] * w[MINDEX2(col_count, red_count, j, k)];
|
| 128 |
__ghost_end(focusy);
|
| 129 |
__ghost_end(focusw);
|
| 130 |
}
|
| 131 |
}
|
|
@@ -145,46 +166,54 @@ void forward(int token, int vocabulary_len, int context_len, int layer_count,
|
|
| 145 |
__reads(
|
| 146 |
"mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
|
| 147 |
"embedding_dim)");
|
| 148 |
float* const embedding =
|
| 149 |
(float*)malloc(MSIZE1(embedding_dim) * sizeof(float));
|
| 150 |
float* const mha_norm = (float*)malloc(MSIZE1(embedding_dim) * sizeof(float));
|
| 151 |
float* const mha_q =
|
| 152 |
(float*)malloc(MSIZE2(q_head_count, head_dim) * sizeof(float));
|
| 153 |
__ghost(assume, "P := in_range(token, 0..vocabulary_len)");
|
| 154 |
for (int e = 0; e < embedding_dim; e++) {
|
|
|
|
|
|
|
| 155 |
__xwrites("&embedding[MINDEX1(embedding_dim, e)] ~> Cell");
|
| 156 |
const __ghost_fn focus_embedding_weight = __ghost_begin(
|
| 157 |
ro_matrix2_focus, "matrix := embedding_weight, i := token, j := e");
|
| 158 |
embedding[MINDEX1(embedding_dim, e)] =
|
| 159 |
embedding_weight[MINDEX2(vocabulary_len, embedding_dim, token, e)];
|
| 160 |
__ghost_end(focus_embedding_weight);
|
| 161 |
}
|
| 162 |
__ghost([&]() {
|
| 163 |
__consumes("embedding ~> Matrix1(embedding_dim)");
|
| 164 |
__consumes("mha_norm ~> UninitMatrix1(embedding_dim)");
|
| 165 |
__produces("&embedding[MINDEX0()] ~> Matrix1(embedding_dim)");
|
| 166 |
__produces("&mha_norm[MINDEX0()] ~> Matrix1(embedding_dim)");
|
| 167 |
__admitted();
|
| 168 |
});
|
| 169 |
for (int l = 0; l < layer_count; l++) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
__xreads(
|
| 171 |
"for i1 in 0..embedding_dim -> &mha_norm_weight[MINDEX2(layer_count, "
|
| 172 |
"embedding_dim, l, i1)] ~> Cell");
|
| 173 |
__xreads(
|
| 174 |
"for q in 0..q_head_count -> for h in 0..head_dim -> for e in "
|
| 175 |
"0..embedding_dim -> &mha_q_weight[MINDEX4(layer_count, q_head_count, "
|
| 176 |
"head_dim, embedding_dim, l, q, h, e)] ~> Cell");
|
| 177 |
rmsnorm(embedding_dim, &mha_norm[MINDEX0()], &embedding[MINDEX0()],
|
| 178 |
&mha_norm_weight[MINDEX2(layer_count, embedding_dim, l, 0)],
|
| 179 |
epsilon);
|
| 180 |
for (int q = 0; q < q_head_count; q++) {
|
|
|
|
|
|
|
| 181 |
__xmodifies(
|
| 182 |
"for i1 in 0..head_dim -> &mha_q[MINDEX2(q_head_count, head_dim, q, "
|
| 183 |
"i1)] ~> UninitCell");
|
| 184 |
__xreads(
|
| 185 |
"for h in 0..head_dim -> for e in 0..embedding_dim -> "
|
| 186 |
"&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim, "
|
| 187 |
"embedding_dim, l, q, h, e)] ~> Cell");
|
| 188 |
matvec(head_dim, embedding_dim,
|
| 189 |
&mha_q[MINDEX2(q_head_count, head_dim, q, 0)],
|
| 190 |
&mha_norm[MINDEX0()],
|
|
@@ -215,23 +244,162 @@ void generate_prompt_proc(int vocabulary_len, int context_len, int layer_count,
|
|
| 215 |
float* ffn_fc_weight, float* ffn_up_weight,
|
| 216 |
float* ffn_out_weight, float* out_norm_weight,
|
| 217 |
float* out_weight, float* k_cache, float* v_cache,
|
| 218 |
float* logits, int* sequence, int sequence_len) {
|
| 219 |
__reads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
|
| 220 |
__reads("mha_norm_weight ~> Matrix2(layer_count, embedding_dim)");
|
| 221 |
__reads("sequence ~> Matrix1(sequence_len)");
|
| 222 |
__reads(
|
| 223 |
"mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
|
| 224 |
"embedding_dim)");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
for (int i = 0; i < sequence_len; i++) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
__xreads("&sequence[MINDEX1(sequence_len, i)] ~> Cell");
|
| 227 |
const int logits_count = 1;
|
| 228 |
const int token = sequence[MINDEX1(sequence_len, i)];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
-
|
| 230 |
-
|
|
|
|
| 231 |
-
|
| 232 |
-
mha_norm_weight, mha_q_weight, mha_k_weight, mha_v_weight,
|
| 233 |
-
mha_out_weight, ffn_norm_weight, ffn_fc_weight, ffn_up_weight,
|
| 234 |
-
ffn_out_weight, out_norm_weight, out_weight, k_cache, v_cache,
|
| 235 |
-
logits, i, logits_count);
|
| 236 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
}
|
| 21 |
|
| 22 |
// Thin wrapper around sqrtf, declared side-effect free for the verifier.
// __pure(): this function's contract claims no heap effects.
// __admitted(): the body is taken on trust (external libm call).
float _sqrtf(float x) {
  __pure();
  __admitted();
  return sqrtf(x);
}
|
| 27 |
|
| 28 |
void rope(int col_count, float* x, int pos) {
|
| 29 |
__modifies("x ~> Matrix1(col_count)");
|
| 30 |
for (int j = 0; j < col_count; j += 2) {
|
| 31 |
+
__strict();
|
| 32 |
__smodifies("x ~> Matrix1(col_count)");
|
| 33 |
__ghost(assume, "P := is_subrange(j..(j + 1), 0..col_count)");
|
| 34 |
const __ghost_fn focus_subrange =
|
| 35 |
__ghost_begin(group_focus_subrange, "sub_range := j..(j + 1)");
|
| 36 |
float freq = 1.f / _powf(500000.f, (float)j / (float)col_count);
|
| 37 |
float val = (float)pos * 2.f;
|
| 38 |
float fcr = _cosf(val);
|
| 39 |
float fci = _sinf(val);
|
| 40 |
__ghost([&]() {
|
| 41 |
__consumes("for i1 in j..(j + 1) -> &x[MINDEX1(col_count, i1)] ~> Cell");
|
| 65 |
|
| 66 |
// In-place, numerically stable softmax over x[0 .. col_count-1]:
//   1) find max_val = max(x)      (subtracted before exp to avoid overflow)
//   2) x[j] = exp(x[j] - max_val), accumulating sum
//   3) x[j] /= sum
// NOTE(review): col_stride is unused in this body — presumably kept for
// signature compatibility with a strided variant; confirm with callers.
// This version carries explicit __strict() loop contracts: each loop lists
// its shared (__s*) and per-iteration exclusive (__x*) resources, including
// the Wand (magic-wand) needed to restore the full matrix after the focus.
void softmax(int col_count, int col_stride, float* x) {
  __modifies("x ~> Matrix1(col_count)");
  // Justify reading x[0] before the loop (index 0 is in range).
  __ghost(assume, "P := in_range(0, 0..col_count)");
  const __ghost_fn max = __ghost_begin(ro_matrix1_focus, "matrix := x, i := 0");
  float max_val = x[MINDEX1(col_count, 0)];
  __ghost_end(max);
  // Focus the sub-range 1..col_count so the max loop owns one cell per step.
  const __ghost_fn focus_subrange =
      __ghost_begin(group_focus_subrange, "sub_range := 1..col_count");
  for (int j = 1; j < col_count; j++) {
    __strict();
    // The wand lets __ghost_end(focus_subrange) reassemble the matrix.
    __smodifies(
        "Wand(for i1 in 1..col_count -> &x[MINDEX1(col_count, i1)] ~> Cell, x "
        "~> Matrix1(col_count))");
    __smodifies("&max_val ~> Cell");
    __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
    if (x[MINDEX1(col_count, j)] > max_val) {
      max_val = x[MINDEX1(col_count, j)];
    }
  }
  __ghost_end(focus_subrange);
  float sum = 0.f;
  // Exponentiate shifted values and accumulate the normalization constant.
  for (int j = 0; j < col_count; j++) {
    __strict();
    __smodifies("&sum ~> Cell");
    __sreads("&max_val ~> Cell");
    __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
    x[MINDEX1(col_count, j)] = _expf(x[MINDEX1(col_count, j)] - max_val);
    sum += x[MINDEX1(col_count, j)];
  }
  // Normalize so the entries sum to 1.
  for (int j = 0; j < col_count; j++) {
    __strict();
    __sreads("&sum ~> Cell");
    __xmodifies("&x[MINDEX1(col_count, j)] ~> Cell");
    x[MINDEX1(col_count, j)] /= sum;
  }
}
|
| 102 |
|
| 103 |
// RMS normalization: y[j] = w[j] * (x[j] / sqrt(mean(x^2) + epsilon)).
// Reads x and the per-channel weights w; writes every cell of y.
// epsilon guards against division by zero for all-zero input.
// This version carries explicit __strict() loop contracts: each loop lists
// its shared (__smodifies/__sreads) and per-iteration (__x*) resources.
void rmsnorm(int col_count, float* y, float* x, float* w, float epsilon) {
  __writes("y ~> Matrix1(col_count)");
  __reads("x ~> Matrix1(col_count)");
  __reads("w ~> Matrix1(col_count)");
  float ss = 0.f;
  // Sum of squares of x; ss is the single shared accumulator.
  for (int j = 0; j < col_count; j++) {
    __strict();
    __smodifies("&ss ~> Cell");
    __xreads("&x[MINDEX1(col_count, j)] ~> Cell");
    ss += x[MINDEX1(col_count, j)] * x[MINDEX1(col_count, j)];
  }
  // ss becomes the reciprocal root-mean-square scale factor.
  ss /= (float)col_count;
  ss += epsilon;
  ss = 1.f / _sqrtf(ss);
  // Scale each channel by the shared factor and its weight.
  for (int j = 0; j < col_count; j++) {
    __strict();
    __sreads("&ss ~> Cell");
    __xwrites("&y[MINDEX1(col_count, j)] ~> Cell");
    __xreads("&x[MINDEX1(col_count, j)] ~> Cell");
    __xreads("&w[MINDEX1(col_count, j)] ~> Cell");
    y[MINDEX1(col_count, j)] =
        w[MINDEX1(col_count, j)] * (ss * x[MINDEX1(col_count, j)]);
  }
}
|
| 127 |
|
| 128 |
void matvec(int col_count, int red_count, float* x, float* y, float* w) {
|
| 129 |
__writes("x ~> Matrix1(col_count)");
|
| 130 |
__reads("y ~> Matrix1(red_count)");
|
| 131 |
__reads("w ~> Matrix2(col_count, red_count)");
|
| 132 |
for (int j = 0; j < col_count; j++) {
|
| 133 |
+
__strict();
|
| 134 |
+
__sreads("y ~> Matrix1(red_count)");
|
| 135 |
+
__sreads("w ~> Matrix2(col_count, red_count)");
|
| 136 |
__xwrites("&x[MINDEX1(col_count, j)] ~> Cell");
|
| 137 |
x[MINDEX1(col_count, j)] = 0.f;
|
| 138 |
for (int k = 0; k < red_count; k++) {
|
| 139 |
+
__strict();
|
| 140 |
+
__smodifies("&x[MINDEX1(col_count, j)] ~> Cell");
|
| 141 |
+
__sreads("y ~> Matrix1(red_count)");
|
| 142 |
+
__sreads("w ~> Matrix2(col_count, red_count)");
|
| 143 |
const __ghost_fn focusy =
|
| 144 |
__ghost_begin(ro_matrix1_focus, "matrix := y, i := k");
|
| 145 |
const __ghost_fn focusw =
|
| 146 |
__ghost_begin(ro_matrix2_focus, "matrix := w, i := j, j := k");
|
| 147 |
x[MINDEX1(col_count, j)] +=
|
| 148 |
y[MINDEX1(red_count, k)] * w[MINDEX2(col_count, red_count, j, k)];
|
| 149 |
__ghost_end(focusy);
|
| 150 |
__ghost_end(focusw);
|
| 151 |
}
|
| 152 |
}
|
| 166 |
__reads(
|
| 167 |
"mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
|
| 168 |
"embedding_dim)");
|
| 169 |
float* const embedding =
|
| 170 |
(float*)malloc(MSIZE1(embedding_dim) * sizeof(float));
|
| 171 |
float* const mha_norm = (float*)malloc(MSIZE1(embedding_dim) * sizeof(float));
|
| 172 |
float* const mha_q =
|
| 173 |
(float*)malloc(MSIZE2(q_head_count, head_dim) * sizeof(float));
|
| 174 |
__ghost(assume, "P := in_range(token, 0..vocabulary_len)");
|
| 175 |
for (int e = 0; e < embedding_dim; e++) {
|
| 176 |
+
__strict();
|
| 177 |
+
__sreads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
|
| 178 |
__xwrites("&embedding[MINDEX1(embedding_dim, e)] ~> Cell");
|
| 179 |
const __ghost_fn focus_embedding_weight = __ghost_begin(
|
| 180 |
ro_matrix2_focus, "matrix := embedding_weight, i := token, j := e");
|
| 181 |
embedding[MINDEX1(embedding_dim, e)] =
|
| 182 |
embedding_weight[MINDEX2(vocabulary_len, embedding_dim, token, e)];
|
| 183 |
__ghost_end(focus_embedding_weight);
|
| 184 |
}
|
| 185 |
__ghost([&]() {
|
| 186 |
__consumes("embedding ~> Matrix1(embedding_dim)");
|
| 187 |
__consumes("mha_norm ~> UninitMatrix1(embedding_dim)");
|
| 188 |
__produces("&embedding[MINDEX0()] ~> Matrix1(embedding_dim)");
|
| 189 |
__produces("&mha_norm[MINDEX0()] ~> Matrix1(embedding_dim)");
|
| 190 |
__admitted();
|
| 191 |
});
|
| 192 |
for (int l = 0; l < layer_count; l++) {
|
| 193 |
+
__strict();
|
| 194 |
+
__smodifies("&mha_norm[MINDEX0()] ~> Matrix1(embedding_dim)");
|
| 195 |
+
__smodifies("mha_q ~> UninitMatrix2(q_head_count, head_dim)");
|
| 196 |
+
__sreads("&embedding[MINDEX0()] ~> Matrix1(embedding_dim)");
|
| 197 |
__xreads(
|
| 198 |
"for i1 in 0..embedding_dim -> &mha_norm_weight[MINDEX2(layer_count, "
|
| 199 |
"embedding_dim, l, i1)] ~> Cell");
|
| 200 |
__xreads(
|
| 201 |
"for q in 0..q_head_count -> for h in 0..head_dim -> for e in "
|
| 202 |
"0..embedding_dim -> &mha_q_weight[MINDEX4(layer_count, q_head_count, "
|
| 203 |
"head_dim, embedding_dim, l, q, h, e)] ~> Cell");
|
| 204 |
rmsnorm(embedding_dim, &mha_norm[MINDEX0()], &embedding[MINDEX0()],
|
| 205 |
&mha_norm_weight[MINDEX2(layer_count, embedding_dim, l, 0)],
|
| 206 |
epsilon);
|
| 207 |
for (int q = 0; q < q_head_count; q++) {
|
| 208 |
+
__strict();
|
| 209 |
+
__sreads("&mha_norm[MINDEX0()] ~> Matrix1(embedding_dim)");
|
| 210 |
__xmodifies(
|
| 211 |
"for i1 in 0..head_dim -> &mha_q[MINDEX2(q_head_count, head_dim, q, "
|
| 212 |
"i1)] ~> UninitCell");
|
| 213 |
__xreads(
|
| 214 |
"for h in 0..head_dim -> for e in 0..embedding_dim -> "
|
| 215 |
"&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim, "
|
| 216 |
"embedding_dim, l, q, h, e)] ~> Cell");
|
| 217 |
matvec(head_dim, embedding_dim,
|
| 218 |
&mha_q[MINDEX2(q_head_count, head_dim, q, 0)],
|
| 219 |
&mha_norm[MINDEX0()],
|
| 244 |
float* ffn_fc_weight, float* ffn_up_weight,
|
| 245 |
float* ffn_out_weight, float* out_norm_weight,
|
| 246 |
float* out_weight, float* k_cache, float* v_cache,
|
| 247 |
float* logits, int* sequence, int sequence_len) {
|
| 248 |
__reads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
|
| 249 |
__reads("mha_norm_weight ~> Matrix2(layer_count, embedding_dim)");
|
| 250 |
__reads("sequence ~> Matrix1(sequence_len)");
|
| 251 |
__reads(
|
| 252 |
"mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
|
| 253 |
"embedding_dim)");
|
| 254 |
+
float* const embedding =
|
| 255 |
+
(float*)malloc(MSIZE2(sequence_len, embedding_dim) * sizeof(float));
|
| 256 |
+
float* const mha_norm =
|
| 257 |
+
(float*)malloc(MSIZE2(sequence_len, embedding_dim) * sizeof(float));
|
| 258 |
+
float* const mha_q = (float*)malloc(
|
| 259 |
+
MSIZE3(sequence_len, q_head_count, head_dim) * sizeof(float));
|
| 260 |
for (int i = 0; i < sequence_len; i++) {
|
| 261 |
+
__strict();
|
| 262 |
+
__sreads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
|
| 263 |
+
__xconsumes(
|
| 264 |
+
"for _v2 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
|
| 265 |
+
"embedding_dim, i, _v2)] ~> UninitCell");
|
| 266 |
+
__xconsumes(
|
| 267 |
+
"for _v1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
|
| 268 |
+
"embedding_dim, i, _v1)] ~> UninitCell");
|
| 269 |
+
__xproduces(
|
| 270 |
+
"&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 271 |
+
"Matrix1(embedding_dim)");
|
| 272 |
+
__xproduces(
|
| 273 |
+
"&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 274 |
+
"Matrix1(embedding_dim)");
|
| 275 |
__xreads("&sequence[MINDEX1(sequence_len, i)] ~> Cell");
|
| 276 |
const int logits_count = 1;
|
| 277 |
const int token = sequence[MINDEX1(sequence_len, i)];
|
| 278 |
+
__ghost(assume, "P := in_range(token, 0..vocabulary_len)", "#_8 <- H");
|
| 279 |
+
for (int e = 0; e < embedding_dim; e++) {
|
| 280 |
+
__strict();
|
| 281 |
+
__sreads("embedding_weight ~> Matrix2(vocabulary_len, embedding_dim)");
|
| 282 |
+
__xwrites(
|
| 283 |
+
"&embedding[MINDEX2(sequence_len, embedding_dim, i, e)] ~> Cell");
|
| 284 |
+
const __ghost_fn focus_embedding_weight = __ghost_begin(
|
| 285 |
+
ro_matrix2_focus, "matrix := embedding_weight, i := token, j := e");
|
| 286 |
+
embedding[MINDEX2(sequence_len, embedding_dim, i, e)] =
|
| 287 |
+
embedding_weight[MINDEX2(vocabulary_len, embedding_dim, token, e)];
|
| 288 |
+
__ghost_end(focus_embedding_weight);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
}
|
| 290 |
+
__ghost([&]() {
|
| 291 |
+
__consumes(
|
| 292 |
+
"for i1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
|
| 293 |
+
"embedding_dim, i, i1)] ~> Cell");
|
| 294 |
+
__consumes(
|
| 295 |
+
"for i1 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
|
| 296 |
+
"embedding_dim, i, i1)] ~> UninitCell");
|
| 297 |
+
__produces(
|
| 298 |
+
"&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 299 |
+
"Matrix1(embedding_dim)");
|
| 300 |
+
__produces(
|
| 301 |
+
"&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 302 |
+
"Matrix1(embedding_dim)");
|
| 303 |
+
__admitted();
|
| 304 |
+
});
|
| 305 |
+
}
|
| 306 |
+
for (int l = 0; l < layer_count; l++) {
|
| 307 |
+
__strict();
|
| 308 |
+
__smodifies(
|
| 309 |
+
"for i in 0..sequence_len -> &mha_norm[MINDEX2(sequence_len, "
|
| 310 |
+
"embedding_dim, i, 0)] ~> Matrix1(embedding_dim)");
|
| 311 |
+
__smodifies("mha_q ~> UninitMatrix3(sequence_len, q_head_count, head_dim)");
|
| 312 |
+
__sreads(
|
| 313 |
+
"for i in 0..sequence_len -> &embedding[MINDEX2(sequence_len, "
|
| 314 |
+
"embedding_dim, i, 0)] ~> Matrix1(embedding_dim)");
|
| 315 |
+
__sreads("mha_norm_weight ~> Matrix2(layer_count, embedding_dim)");
|
| 316 |
+
__sreads(
|
| 317 |
+
"mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
|
| 318 |
+
"embedding_dim)");
|
| 319 |
+
for (int i = 0; i < sequence_len; i++) {
|
| 320 |
+
__strict();
|
| 321 |
+
__sreads("mha_norm_weight ~> Matrix2(layer_count, embedding_dim)");
|
| 322 |
+
__sreads(
|
| 323 |
+
"mha_q_weight ~> Matrix4(layer_count, q_head_count, head_dim, "
|
| 324 |
+
"embedding_dim)");
|
| 325 |
+
__xmodifies(
|
| 326 |
+
"&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 327 |
+
"Matrix1(embedding_dim)");
|
| 328 |
+
__xmodifies(
|
| 329 |
+
"for i1 in 0..q_head_count -> for i2 in 0..head_dim -> "
|
| 330 |
+
"&mha_q[MINDEX3(sequence_len, q_head_count, head_dim, i, i1, i2)] ~> "
|
| 331 |
+
"UninitCell");
|
| 332 |
+
__xreads(
|
| 333 |
+
"&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 334 |
+
"Matrix1(embedding_dim)");
|
| 335 |
+
const __ghost_fn __ghost_pair_2 = __ghost_begin(
|
| 336 |
+
ro_group_focus,
|
| 337 |
+
"i := l, items := fun (l: int) -> for i1 in 0..embedding_dim -> "
|
| 338 |
+
"&mha_norm_weight[MINDEX2(layer_count, embedding_dim, l, i1)] ~> "
|
| 339 |
+
"Cell");
|
| 340 |
+
const __ghost_fn __ghost_pair_1 = __ghost_begin(
|
| 341 |
+
ro_group_focus,
|
| 342 |
+
"i := l, items := fun (l: int) -> for q in 0..q_head_count -> for h "
|
| 343 |
+
"in 0..head_dim -> for e in 0..embedding_dim -> "
|
| 344 |
+
"&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim, "
|
| 345 |
+
"embedding_dim, l, q, h, e)] ~> Cell");
|
| 346 |
+
rmsnorm(
|
| 347 |
+
embedding_dim, &mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)],
|
| 348 |
+
&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)],
|
| 349 |
+
&mha_norm_weight[MINDEX2(layer_count, embedding_dim, l, 0)], epsilon);
|
| 350 |
+
for (int q = 0; q < q_head_count; q++) {
|
| 351 |
+
__strict();
|
| 352 |
+
__sreads(
|
| 353 |
+
"&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 354 |
+
"Matrix1(embedding_dim)");
|
| 355 |
+
__xmodifies(
|
| 356 |
+
"for i1 in 0..head_dim -> &mha_q[MINDEX3(sequence_len, "
|
| 357 |
+
"q_head_count, head_dim, i, q, i1)] ~> UninitCell");
|
| 358 |
+
__xreads(
|
| 359 |
+
"for h in 0..head_dim -> for e in 0..embedding_dim -> "
|
| 360 |
+
"&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim, "
|
| 361 |
+
"embedding_dim, l, q, h, e)] ~> Cell");
|
| 362 |
+
matvec(head_dim, embedding_dim,
|
| 363 |
+
&mha_q[MINDEX3(sequence_len, q_head_count, head_dim, i, q, 0)],
|
| 364 |
+
&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)],
|
| 365 |
+
&mha_q_weight[MINDEX4(layer_count, q_head_count, head_dim,
|
| 366 |
+
embedding_dim, l, q, 0, 0)]);
|
| 367 |
+
}
|
| 368 |
+
__ghost_end(__ghost_pair_1);
|
| 369 |
+
__ghost_end(__ghost_pair_2);
|
| 370 |
+
}
|
| 371 |
+
}
|
| 372 |
+
for (int i = 0; i < sequence_len; i++) {
|
| 373 |
+
__strict();
|
| 374 |
+
__xconsumes(
|
| 375 |
+
"&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 376 |
+
"Matrix1(embedding_dim)");
|
| 377 |
+
__xconsumes(
|
| 378 |
+
"&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 379 |
+
"Matrix1(embedding_dim)");
|
| 380 |
+
__xproduces(
|
| 381 |
+
"for _v2 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
|
| 382 |
+
"embedding_dim, i, _v2)] ~> UninitCell");
|
| 383 |
+
__xproduces(
|
| 384 |
+
"for _v1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
|
| 385 |
+
"embedding_dim, i, _v1)] ~> UninitCell");
|
| 386 |
+
__ghost([&]() {
|
| 387 |
+
__consumes(
|
| 388 |
+
"&embedding[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 389 |
+
"Matrix1(embedding_dim)");
|
| 390 |
+
__consumes(
|
| 391 |
+
"&mha_norm[MINDEX2(sequence_len, embedding_dim, i, 0)] ~> "
|
| 392 |
+
"Matrix1(embedding_dim)");
|
| 393 |
+
__produces(
|
| 394 |
+
"for i1 in 0..embedding_dim -> &embedding[MINDEX2(sequence_len, "
|
| 395 |
+
"embedding_dim, i, i1)] ~> Cell");
|
| 396 |
+
__produces(
|
| 397 |
+
"for i1 in 0..embedding_dim -> &mha_norm[MINDEX2(sequence_len, "
|
| 398 |
+
"embedding_dim, i, i1)] ~> UninitCell");
|
| 399 |
+
__admitted();
|
| 400 |
+
});
|
| 401 |
+
}
|
| 402 |
+
free(mha_q);
|
| 403 |
+
free(mha_norm);
|
| 404 |
+
free(embedding);
|
| 405 |
}
|