Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 167 additions & 1 deletion transformer_vm/model/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
#ifdef __APPLE__
#define ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#elif __has_include(<cblas.h>)
#define HAS_CBLAS
extern "C" {
#include <cblas.h>
}
#endif

#include "hull2d_cht.h"
Expand Down Expand Up @@ -67,6 +72,7 @@ struct Model {
std::vector<std::vector<int>> attn_erase;
std::vector<std::vector<int>> ffn_erase;
std::vector<std::vector<TieBreak>> head_tb;
std::vector<std::vector<int>> head_type; // 1 = passthrough, 0 = lookup
};

static void load(Model& m, const char* path) {
Expand Down Expand Up @@ -138,6 +144,20 @@ static void load(Model& m, const char* path) {
}
}
}

int32_t has_ht = 0;
if (fread(&has_ht, 4, 1, f) == 1 && has_ht) {
m.head_type.resize(L);
int H = m.H;
for (int l = 0; l < L; l++) {
m.head_type[l].resize(H, 0);
for (int h = 0; h < H; h++) {
int32_t v;
fread(&v, 4, 1, f);
m.head_type[l][h] = v;
}
}
}
fclose(f);

m.head_sp.build(m.head, V, D);
Expand All @@ -158,7 +178,7 @@ static inline void add_position_encoding(double* x, int pos) {
static inline void matvec(const double* __restrict__ W,
const double* __restrict__ x,
double* __restrict__ y, int rows, int cols) {
#if defined(__APPLE__) && !defined(NO_BLAS)
#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
cblas_dgemv(CblasRowMajor, CblasNoTrans, rows, cols,
1.0, W, cols, x, 1, 0.0, y, 1);
#else
Expand Down Expand Up @@ -211,6 +231,7 @@ int main(int argc, char** argv) {
int passed = 0, failed = 0, skipped = 0;
long total_tok = 0, total_ops = 0;
double total_time = 0, t_proj = 0, t_hull = 0, t_head = 0;
std::vector<int> last_ids; // save last program's tokens for batched test

for (int ai = 2; ai < argc; ai++) {
if (strstr(argv[ai], "_ref")) continue;
Expand Down Expand Up @@ -353,6 +374,7 @@ int main(int argc, char** argv) {
auto t1 = Clock::now();
double dt = secs(t0, t1);
int nt = (int)ids.size(), no = 0;
last_ids = ids; // save for batched test
std::string output_bytes;
for (int i : ids) {
const auto& s = m.name[i];
Expand Down Expand Up @@ -451,6 +473,150 @@ int main(int argc, char** argv) {
}
}

// ── Batched verification: dgemm projections + sequential hull ────
// Re-runs the last program with whole-sequence GEMMs (one dgemm per
// projection instead of one matvec per token), then spot-checks the
// prediction made at position T-2 against the token the sequential run
// actually produced at position T-1.
//
// BUGFIX: the guard previously only required !last_ids.empty().  With a
// single-token program (T == 1) the spot check computes last = T - 2 = -1
// and reads X[-D .. -1] out of bounds.  Require at least two tokens.
if (total_tok > 0 && !brute && last_ids.size() > 1) {
    int T = (int)last_ids.size();
    printf("\n── Batched verify (%d tokens) ──\n", T);

    // Input matrix X[T,D]: token embedding + positional encoding per row.
    std::vector<double> X(T * D, 0.0);
    for (int t = 0; t < T; t++) {
        const double* e = m.emb + last_ids[t] * D;
        std::copy(e, e + D, &X[t * D]);
        add_position_encoding(&X[t * D], t);
    }

    // Fresh hulls so verification is independent of the sequential pass.
    std::vector<HardAttentionHead> vhulls(L * H);

    // Scratch buffers reused across layers (allocated once).
    std::vector<double> QKV(T * 3*D), HO(T * D), AO(T * D);
    std::vector<double> FF(T * 2*F), GV(T * F), FO(T * D);
    int vseq = 0;                 // global sequence counter fed to hull inserts
    double bt_proj = 0, bt_hull = 0;

    auto bstart = Clock::now();
    for (int l = 0; l < L; l++) {
        const auto& ly = m.ly[l];

        // Batch QKV: dgemm X[T,D] @ W_qkv^T[D,3D] -> QKV[T,3D]
        auto pa = Clock::now();
#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
        cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    T, 3*D, D, 1.0, X.data(), D, ly.qkv, D, 0.0, QKV.data(), 3*D);
#else
        for (int t = 0; t < T; t++)
            matvec(ly.qkv, &X[t*D], &QKV[t*3*D], 3*D, D);
#endif
        auto pb = Clock::now();

        // Attention: passthrough → V[t], gather → V[round(q)], hull → search.
        // Heads are independent (each writes only its own 2-lane slice of HO
        // and owns vhulls[l*H+h]), so the loop parallelizes safely over h.
        std::fill(HO.begin(), HO.end(), 0.0);
        int vseq_base = vseq;
#pragma omp parallel for schedule(static)
        for (int h = 0; h < H; h++) {
            int ht = (!m.head_type.empty()) ? m.head_type[l][h] : 0;
            if (ht == 1) {
                // Passthrough: output = V[t].
                for (int t = 0; t < T; t++) {
                    HO[t*D + h*2]     = QKV[t*3*D + 2*D + h*2];
                    HO[t*D + h*2 + 1] = QKV[t*3*D + 2*D + h*2 + 1];
                }
            } else if (ht == 2) {
                // Gather: position-keyed lookup.  HARD_K is unknown here, so
                // recover the target index from the score directly:
                //   score(s) = qx*kx + qy*ky = qx*2s + qy*(-s² + ε)
                // d/ds = 2*qx - 2s*qy = 0  ⇒  maximized at s = qx/qy,
                // so the gather index is round(qx / qy), clamped causally.
                for (int t = 0; t < T; t++) {
                    double qx = QKV[t*3*D + h*2];
                    double qy = QKV[t*3*D + h*2 + 1];
                    int idx = (qy != 0.0) ? (int)std::round(qx / qy) : 0;
                    if (idx < 0) idx = 0;
                    if (idx >= T) idx = T - 1;
                    if (idx > t) idx = t;  // causality: never look ahead
                    HO[t*D + h*2]     = QKV[idx*3*D + 2*D + h*2];
                    HO[t*D + h*2 + 1] = QKV[idx*3*D + 2*D + h*2 + 1];
                }
            } else {
                // Lookup: insert (k, v) then query with q via the hull, in
                // token order so causal visibility matches the sequential run.
                TieBreak tb = (!m.head_tb.empty()) ? m.head_tb[l][h] : TieBreak::AVERAGE;
                for (int t = 0; t < T; t++) {
                    double *k = &QKV[t*3*D + D], *v = k + D, *q = &QKV[t*3*D];
                    vhulls[l*H+h].insert(&k[h*2], &v[h*2], vseq_base + t);
                    vhulls[l*H+h].query(&q[h*2], tb, &HO[t*D + h*2]);
                }
            }
        }
        vseq += T;
        auto pc = Clock::now();

        // Batch out projection: dgemm HO[T,D] @ W_out^T[D,D] -> AO[T,D]
#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
        cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    T, D, D, 1.0, HO.data(), D, ly.out, D, 0.0, AO.data(), D);
#else
        for (int t = 0; t < T; t++)
            matvec(ly.out, &HO[t*D], &AO[t*D], D, D);
#endif
        for (int i = 0; i < T*D; i++) X[i] += AO[i];  // residual add

        // Erase attention slots (zero designated channels after attention).
        if (!m.attn_erase.empty())
            for (int s : m.attn_erase[l])
                for (int t = 0; t < T; t++) X[t*D + s] = 0.0;

        // Batch FFN: X @ W_in^T -> FF[T,2F]; gated ReLU; GV @ W_out^T -> FO.
#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
        cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    T, 2*F, D, 1.0, X.data(), D, ly.fi, D, 0.0, FF.data(), 2*F);
#else
        for (int t = 0; t < T; t++)
            matvec(ly.fi, &X[t*D], &FF[t*2*F], 2*F, D);
#endif
        for (int t = 0; t < T; t++)
            for (int j = 0; j < F; j++)
                GV[t*F+j] = (FF[t*2*F+j] > 0 ? FF[t*2*F+j] : 0.0) * FF[t*2*F+F+j];
#if (defined(__APPLE__) || defined(HAS_CBLAS)) && !defined(NO_BLAS)
        cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    T, D, F, 1.0, GV.data(), F, ly.fo, F, 0.0, FO.data(), D);
#else
        for (int t = 0; t < T; t++)
            matvec(ly.fo, &GV[t*F], &FO[t*D], D, F);
#endif
        for (int i = 0; i < T*D; i++) X[i] += FO[i];  // residual add

        // Erase FFN slots (zero designated channels after the FFN).
        if (!m.ffn_erase.empty())
            for (int s : m.ffn_erase[l])
                for (int t = 0; t < T; t++) X[t*D + s] = 0.0;

        auto pd = Clock::now();
        bt_proj += secs(pa, pb) + secs(pc, pd);
        bt_hull += secs(pb, pc);
    }
    auto bend = Clock::now();
    double bt_total = secs(bstart, bend);

    // Verify: argmax over the sparse head matrix at the second-to-last
    // position, which is the position that predicts the final token.
    // Safe: the guard above guarantees T >= 2, so last >= 0.
    const auto& sp = m.head_sp;
    int last = T - 2;
    double bs = -1e300; int best = 0;
    for (int i = 0; i < sp.rows; i++) {
        double s = 0;
        for (int k = sp.ptr[i]; k < sp.ptr[i+1]; k++)
            s += sp.val[k] * X[last*D + sp.col[k]];
        if (s > bs) { bs = s; best = i; }
    }
    bool match = (best == last_ids[T-1]);

    printf("  Batched total:  %.3fs (%7.0f tok/s)\n", bt_total, T/bt_total);
    printf("  Batched proj:   %.3fs (%4.1f%%)\n", bt_proj, 100*bt_proj/bt_total);
    printf("  Batched hull:   %.3fs (%4.1f%%)\n", bt_hull, 100*bt_hull/bt_total);
    printf("  Spot check: %s (pos %d: pred=%s, actual=%s)\n",
           match ? "OK" : "MISMATCH", last,
           m.name[best].c_str(), m.name[last_ids[T-1]].c_str());
    printf("  Speedup vs sequential: %.2fx\n", total_time / bt_total);
}

printf("\n%d passed, %d failed, %d no-ref\n", passed, failed, skipped);
if (total_time > 0) {
double t_misc = total_time - t_proj - t_hull - t_head;
Expand Down
48 changes: 48 additions & 0 deletions transformer_vm/model/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,36 @@ def expr_to_tensor(expr):
model.ffn_erase = [sorted(erased_at[(li, 3)]) for li in range(n_layers)]
model.head_tiebreak = all_tiebreak

# Head type: 0=lookup (hull), 1=passthrough (V[t]), 2=gather (V[round(q)])
from transformer_vm.graph.core import InputDimension as _ID
from transformer_vm.graph.core import _all_lookups
all_head_type = []
for li in range(n_layers):
layer_ht = [0] * n_heads
if li in head_map:
for h, info in head_map[li].items():
if info.get("type") == "passthrough":
layer_ht[h] = 1
elif info.get("type") == "lookup":
# Check if the lookup keys on position (1D key = position)
lu_name = info.get("lookup", "")
# Find the LookUp object by matching name
for lu in pg.all_lookups if program_graph else _all_lookups:
if (lu.name or f"lookup_{lu.id}") == lu_name:
kx = lu.key_exprs_2d[0]
# Position-keyed if kx = 2*position (one term, the position dim)
if len(kx.terms) == 1:
dim, coeff = next(iter(kx.terms.items()))
if isinstance(dim, _ID) and dim.name == "position" and coeff == 2:
layer_ht[h] = 2
break
all_head_type.append(layer_ht)
model.head_type = all_head_type
n_pt = sum(v == 1 for layer in all_head_type for v in layer)
n_ga = sum(v == 2 for layer in all_head_type for v in layer)
n_hu = sum(v == 0 for layer in all_head_type for v in layer)
logger.info(" head_type: %d passthrough, %d gather, %d hull", n_pt, n_ga, n_hu)

n_persist = sum(len(p1) + len(p2) for _, p1, _, p2 in std_layers)
logger.info(
"Built model: d_model=%d, n_layers=%d, n_heads=%d, d_ffn=%d, erase=%s",
Expand Down Expand Up @@ -697,6 +727,14 @@ def W(t):
for h in range(H):
f.write(struct.pack("<i", model.head_tiebreak[li][h]))

has_head_type = hasattr(model, "head_type")
f.write(struct.pack("<i", 1 if has_head_type else 0))
if has_head_type:
H = model.attn[0].num_heads
for li in range(n_layers):
for h in range(H):
f.write(struct.pack("<i", model.head_type[li][h]))

logger.info("Saved weights to %s (%s bytes)", path, f"{os.path.getsize(path):,}")


Expand Down Expand Up @@ -764,6 +802,16 @@ def R(shape):
layer_tb = [struct.unpack("<i", f.read(4))[0] for _ in range(n_heads)]
model.head_tiebreak.append(layer_tb)

# Head type: 1 = passthrough, 0 = lookup
remaining = f.read(4)
if remaining and len(remaining) == 4:
has_head_type = struct.unpack("<i", remaining)[0]
if has_head_type:
model.head_type = []
for _ in range(n_layers):
layer_ht = [struct.unpack("<i", f.read(4))[0] for _ in range(n_heads)]
model.head_type.append(layer_ht)

logger.info(
"Loaded weights from %s (vocab=%d, d_model=%d, n_layers=%d, n_heads=%d, d_ffn=%d)",
path,
Expand Down
20 changes: 16 additions & 4 deletions transformer_vm/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def run_model_program(
_CPP_BINARY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "transformer")


def _find_libomp_prefix():
"""Find libomp install prefix on macOS (homebrew)."""
for prefix in ["/opt/homebrew/opt/libomp", "/usr/local/opt/libomp"]:
if os.path.isfile(os.path.join(prefix, "lib", "libomp.dylib")):
return prefix
return None


def _build_cpp_engine():
"""Build the C++ inference engine if not already built."""
binary = os.path.abspath(_CPP_BINARY)
Expand All @@ -99,6 +107,7 @@ def _build_cpp_engine():
logger.info("[engine] Compiling C++ inference engine...")
attn_dir = os.path.join(os.path.dirname(source), "..", "attention")
if platform.system() == "Darwin":
omp_prefix = _find_libomp_prefix()
cmd = [
"clang++",
"-std=c++17",
Expand All @@ -107,12 +116,15 @@ def _build_cpp_engine():
"Accelerate",
"-I",
attn_dir,
source,
"-o",
binary,
]
if omp_prefix:
cmd += ["-Xclang", "-fopenmp",
f"-I{omp_prefix}/include", f"-L{omp_prefix}/lib", "-lomp"]
cmd += [source, "-o", binary]
else:
cmd = ["g++", "-std=c++17", "-O3", "-I", attn_dir, source, "-o", binary]
cmd = ["g++", "-std=c++17", "-O3", "-fopenmp",
"-I", attn_dir, source, "-o", binary,
"-lopenblas"]
try:
subprocess.check_call(cmd)
logger.info("[engine] Built: %s", binary)
Expand Down