diff --git a/transformer_vm/model/transformer.cpp b/transformer_vm/model/transformer.cpp
index b287f61..27c5ae3 100644
--- a/transformer_vm/model/transformer.cpp
+++ b/transformer_vm/model/transformer.cpp
@@ -22,6 +22,11 @@
 #ifdef __APPLE__
 #define ACCELERATE_NEW_LAPACK
 #include <Accelerate/Accelerate.h>
+#elif __has_include(<cblas.h>)
+#define HAS_CBLAS
+extern "C" {
+#include <cblas.h>
+}
 #endif
 
 #include "hull2d_cht.h"
@@ -67,6 +72,7 @@ struct Model {
     std::vector<std::vector<int>> attn_erase;
     std::vector<std::vector<int>> ffn_erase;
     std::vector<std::vector<TieBreak>> head_tb;
+    std::vector<std::vector<int>> head_type;  // 1 = passthrough, 0 = lookup
 };
 
 static void load(Model& m, const char* path) {
@@ -138,6 +144,20 @@ static void load(Model& m, const char* path) {
             }
         }
     }
+
+    int32_t has_ht = 0;
+    if (fread(&has_ht, 4, 1, f) == 1 && has_ht) {
+        m.head_type.resize(L);
+        int H = m.H;
+        for (int l = 0; l < L; l++) {
+            m.head_type[l].resize(H, 0);
+            for (int h = 0; h < H; h++) {
+                int32_t v;
+                fread(&v, 4, 1, f);
+                m.head_type[l][h] = v;
+            }
+        }
+    }
     fclose(f);
 
     m.head_sp.build(m.head, V, D);
@@ -158,7 +178,7 @@ static inline void add_position_encoding(double* x, int pos) {
 
 static inline void matvec(const double* __restrict__ W, const double* __restrict__ x,
                           double* __restrict__ y, int rows, int cols) {
-#if defined(__APPLE__) && !defined(NO_BLAS)
+#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
     cblas_dgemv(CblasRowMajor, CblasNoTrans, rows, cols, 1.0, W, cols,
                 x, 1, 0.0, y, 1);
 #else
@@ -211,6 +231,7 @@ int main(int argc, char** argv) {
     int passed = 0, failed = 0, skipped = 0;
     long total_tok = 0, total_ops = 0;
     double total_time = 0, t_proj = 0, t_hull = 0, t_head = 0;
+    std::vector<int> last_ids;  // save last program's tokens for batched test
 
     for (int ai = 2; ai < argc; ai++) {
         if (strstr(argv[ai], "_ref")) continue;
@@ -353,6 +374,7 @@
         auto t1 = Clock::now();
         double dt = secs(t0, t1);
         int nt = (int)ids.size(), no = 0;
+        last_ids = ids;  // save for batched test
         std::string output_bytes;
         for (int i : ids) {
             const auto& s = m.name[i];
@@ -451,6 +473,150 @@ int main(int argc, char** argv) {
         }
     }
 
+    // ── Batched verification: dgemm projections + sequential hull ────
+    if (total_tok > 0 && !brute && !last_ids.empty()) {
+        int T = (int)last_ids.size();
+        printf("\n── Batched verify (%d tokens) ──\n", T);
+
+        std::vector<double> X(T * D, 0.0);
+        for (int t = 0; t < T; t++) {
+            const double* e = m.emb + last_ids[t] * D;
+            std::copy(e, e + D, &X[t * D]);
+            add_position_encoding(&X[t * D], t);
+        }
+
+        // Fresh hulls for verification
+        std::vector<Hull2D> vhulls(L * H);  // NOTE(review): element type was lost in extraction — confirm class name against hull2d_cht.h
+
+        std::vector<double> QKV(T * 3*D), HO(T * D), AO(T * D);
+        std::vector<double> FF(T * 2*F), GV(T * F), FO(T * D);
+        int vseq = 0;
+        double bt_proj = 0, bt_hull = 0;
+
+        auto bstart = Clock::now();
+        for (int l = 0; l < L; l++) {
+            const auto& ly = m.ly[l];
+
+            // Batch QKV: dgemm X[T,D] @ W_qkv^T[D,3D] -> QKV[T,3D]
+            auto pa = Clock::now();
+#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
+            cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        T, 3*D, D, 1.0, X.data(), D, ly.qkv, D, 0.0, QKV.data(), 3*D);
+#else
+            for (int t = 0; t < T; t++)
+                matvec(ly.qkv, &X[t*D], &QKV[t*3*D], 3*D, D);
+#endif
+            auto pb = Clock::now();
+
+            // Attention: passthrough → V[t], gather → V[round(q)], hull → search.
+            std::fill(HO.begin(), HO.end(), 0.0);
+            int vseq_base = vseq;
+            #pragma omp parallel for schedule(static)
+            for (int h = 0; h < H; h++) {
+                int ht = (!m.head_type.empty()) ? m.head_type[l][h] : 0;
+                if (ht == 1) {
+                    // Passthrough: output = V[t].
+                    for (int t = 0; t < T; t++) {
+                        HO[t*D + h*2] = QKV[t*3*D + 2*D + h*2];
+                        HO[t*D + h*2 + 1] = QKV[t*3*D + 2*D + h*2 + 1];
+                    }
+                } else if (ht == 2) {
+                    // Gather: position-keyed lookup. q_1d = qx / (HARD_K * sqrt(2) * 2).
+                    // But we don't know HARD_K here. Use the ratio:
+                    //   score = qx*kx + qy*ky = qx*2s + qy*(-s²+ε)
+                    // Maximized at s = qx/qy (from d/ds = 2*qx - 2s*qy = 0).
+                    // So the gather index is round(qx / qy).
+                    for (int t = 0; t < T; t++) {
+                        double qx = QKV[t*3*D + h*2];
+                        double qy = QKV[t*3*D + h*2 + 1];
+                        int idx = (qy != 0.0) ? (int)std::round(qx / qy) : 0;
+                        if (idx < 0) idx = 0;
+                        if (idx >= T) idx = T - 1;
+                        if (idx > t) idx = t;  // causality
+                        HO[t*D + h*2] = QKV[idx*3*D + 2*D + h*2];
+                        HO[t*D + h*2 + 1] = QKV[idx*3*D + 2*D + h*2 + 1];
+                    }
+                } else {
+                    // Lookup: use hull.
+                    TieBreak tb = (!m.head_tb.empty()) ? m.head_tb[l][h] : TieBreak::AVERAGE;
+                    for (int t = 0; t < T; t++) {
+                        double *k = &QKV[t*3*D + D], *v = k + D, *q = &QKV[t*3*D];
+                        vhulls[l*H+h].insert(&k[h*2], &v[h*2], vseq_base + t);
+                        vhulls[l*H+h].query(&q[h*2], tb, &HO[t*D + h*2]);
+                    }
+                }
+            }
+            vseq += T;
+            auto pc = Clock::now();
+
+            // Batch out projection: dgemm
+#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
+            cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        T, D, D, 1.0, HO.data(), D, ly.out, D, 0.0, AO.data(), D);
+#else
+            for (int t = 0; t < T; t++)
+                matvec(ly.out, &HO[t*D], &AO[t*D], D, D);
+#endif
+            for (int i = 0; i < T*D; i++) X[i] += AO[i];
+
+            // Erase attention slots
+            if (!m.attn_erase.empty())
+                for (int s : m.attn_erase[l])
+                    for (int t = 0; t < T; t++) X[t*D + s] = 0.0;
+
+            // Batch FFN: dgemm
+#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
+            cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        T, 2*F, D, 1.0, X.data(), D, ly.fi, D, 0.0, FF.data(), 2*F);
+#else
+            for (int t = 0; t < T; t++)
+                matvec(ly.fi, &X[t*D], &FF[t*2*F], 2*F, D);
+#endif
+            for (int t = 0; t < T; t++)
+                for (int j = 0; j < F; j++)
+                    GV[t*F+j] = (FF[t*2*F+j] > 0 ? FF[t*2*F+j] : 0.0) * FF[t*2*F+F+j];
+#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
+            cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        T, D, F, 1.0, GV.data(), F, ly.fo, F, 0.0, FO.data(), D);
+#else
+            for (int t = 0; t < T; t++)
+                matvec(ly.fo, &GV[t*F], &FO[t*D], D, F);
+#endif
+            for (int i = 0; i < T*D; i++) X[i] += FO[i];
+
+            // Erase FFN slots
+            if (!m.ffn_erase.empty())
+                for (int s : m.ffn_erase[l])
+                    for (int t = 0; t < T; t++) X[t*D + s] = 0.0;
+
+            auto pd = Clock::now();
+            bt_proj += secs(pa, pb) + secs(pc, pd);
+            bt_hull += secs(pb, pc);
+        }
+        auto bend = Clock::now();
+        double bt_total = secs(bstart, bend);
+
+        // Verify: check argmax at last position
+        const auto& sp = m.head_sp;
+        int last = T - 2;  // second-to-last position predicts last token
+        double bs = -1e300; int best = 0;
+        for (int i = 0; i < sp.rows; i++) {
+            double s = 0;
+            for (int k = sp.ptr[i]; k < sp.ptr[i+1]; k++)
+                s += sp.val[k] * X[last*D + sp.col[k]];
+            if (s > bs) { bs = s; best = i; }
+        }
+        bool match = (best == last_ids[T-1]);
+
+        printf(" Batched total: %.3fs (%7.0f tok/s)\n", bt_total, T/bt_total);
+        printf(" Batched proj: %.3fs (%4.1f%%)\n", bt_proj, 100*bt_proj/bt_total);
+        printf(" Batched hull: %.3fs (%4.1f%%)\n", bt_hull, 100*bt_hull/bt_total);
+        printf(" Spot check: %s (pos %d: pred=%s, actual=%s)\n",
+               match ? "OK" : "MISMATCH", last,
+               m.name[best].c_str(), m.name[last_ids[T-1]].c_str());
+        printf(" Speedup vs sequential: %.2fx\n", total_time / bt_total);
+    }
+
     printf("\n%d passed, %d failed, %d no-ref\n", passed, failed, skipped);
     if (total_time > 0) {
         double t_misc = total_time - t_proj - t_hull - t_head;
diff --git a/transformer_vm/model/weights.py b/transformer_vm/model/weights.py
index 197e9b5..cf93bd8 100644
--- a/transformer_vm/model/weights.py
+++ b/transformer_vm/model/weights.py
@@ -565,6 +565,36 @@ def expr_to_tensor(expr):
     model.ffn_erase = [sorted(erased_at[(li, 3)]) for li in range(n_layers)]
     model.head_tiebreak = all_tiebreak
 
+    # Head type: 0=lookup (hull), 1=passthrough (V[t]), 2=gather (V[round(q)])
+    from transformer_vm.graph.core import InputDimension as _ID
+    from transformer_vm.graph.core import _all_lookups
+    all_head_type = []
+    for li in range(n_layers):
+        layer_ht = [0] * n_heads
+        if li in head_map:
+            for h, info in head_map[li].items():
+                if info.get("type") == "passthrough":
+                    layer_ht[h] = 1
+                elif info.get("type") == "lookup":
+                    # Check if the lookup keys on position (1D key = position)
+                    lu_name = info.get("lookup", "")
+                    # Find the LookUp object by matching name
+                    for lu in pg.all_lookups if program_graph else _all_lookups:
+                        if (lu.name or f"lookup_{lu.id}") == lu_name:
+                            kx = lu.key_exprs_2d[0]
+                            # Position-keyed if kx = 2*position (one term, the position dim)
+                            if len(kx.terms) == 1:
+                                dim, coeff = next(iter(kx.terms.items()))
+                                if isinstance(dim, _ID) and dim.name == "position" and coeff == 2:
+                                    layer_ht[h] = 2
+                            break
+        all_head_type.append(layer_ht)
+    model.head_type = all_head_type
+    n_pt = sum(v == 1 for layer in all_head_type for v in layer)
+    n_ga = sum(v == 2 for layer in all_head_type for v in layer)
+    n_hu = sum(v == 0 for layer in all_head_type for v in layer)
+    logger.info(" head_type: %d passthrough, %d gather, %d hull", n_pt, n_ga, n_hu)
+
     n_persist = sum(len(p1) + len(p2) for _, p1, _, p2 in std_layers)
     logger.info(
         "Built model: d_model=%d, n_layers=%d, n_heads=%d, d_ffn=%d, erase=%s",
@@ -697,6 +727,14 @@ def W(t):
     for h in range(H):
         f.write(struct.pack("