.json
+// and data/metadata.json, renders header + token-highlighted examples.
+
+const params = new URLSearchParams(location.search);
+const featureId = parseInt(params.get('feature') ?? '0', 10);
+
+function pad(id, width) {
+ return String(id).padStart(width, '0');
+}
+
+function widthFor(nFeatures) {
+ return Math.max(4, String(nFeatures - 1).length);
+}
+
+async function main() {
+ const md = await fetch('data/metadata.json').then(r => r.json());
+ const w = widthFor(md.n_features);
+ const f = await fetch(`data/features/${pad(featureId, w)}.json`).then(r => r.json());
+
+ document.title = `Feature #${f.feature_id} — MoLT dashboard`;
+ document.getElementById('ckpt').textContent =
+ (md.hf_filename || md.ckpt_path) + ` · layer ${md.layer}`;
+ document.getElementById('title').textContent = `Feature #${f.feature_id}`;
+
+ const meta = document.getElementById('meta');
+ meta.innerHTML = '';
+ addMeta(meta, `tier ${f.tier} · rank ${f.rank}`);
+ addMeta(meta, `activation rate: ${(f.activation_rate * 100).toFixed(3)}%`);
+ addMeta(meta, `max activation: ${f.max_activation.toFixed(3)}`);
+ addMeta(meta, `examples: ${f.examples.length}`);
+
+ // Prev/next nav.
+ const nav = document.createElement('span');
+ if (featureId > 0) {
+ nav.appendChild(link(`?feature=${featureId - 1}`, '← prev'));
+ nav.appendChild(document.createTextNode(' '));
+ }
+ if (featureId < md.n_features - 1) {
+ nav.appendChild(link(`?feature=${featureId + 1}`, 'next →'));
+ }
+ meta.appendChild(nav);
+
+ const root = document.getElementById('examples');
+ if (f.examples.length === 0) {
+ root.innerHTML = 'This transform did not fire on any token in the sampled corpus.
';
+ return;
+ }
+ for (const ex of f.examples) {
+ root.appendChild(renderExample(ex, f.max_activation));
+ }
+}
+
+function addMeta(parent, text) {
+ const s = document.createElement('span');
+ s.textContent = text;
+ parent.appendChild(s);
+}
+
+function link(href, text) {
+ const a = document.createElement('a');
+ a.href = href;
+ a.textContent = text;
+ return a;
+}
+
+function renderExample(ex, globalMax) {
+ const wrap = document.createElement('div');
+ wrap.className = 'example';
+
+ const header = document.createElement('div');
+ header.className = 'example-header';
+ header.textContent = `peak ${ex.peak_activation.toFixed(3)} · position ${ex.peak_token_pos} of ${ex.tokens.length}`;
+ wrap.appendChild(header);
+
+ const tokens = document.createElement('div');
+ tokens.className = 'tokens';
+ // Scale alpha by the per-feature global max so different sequences are
+ // visually comparable within the same feature.
+ const scale = globalMax > 0 ? globalMax : 1.0;
+ for (let i = 0; i < ex.tokens.length; i++) {
+ const span = document.createElement('span');
+ span.className = 'tok';
+ if (i === ex.peak_token_pos) span.classList.add('peak');
+ const a = ex.activations[i];
+ const alpha = Math.max(0, Math.min(1, a / scale));
+ span.style.backgroundColor = `rgba(220, 50, 50, ${alpha.toFixed(3)})`;
+ span.title = `act ${a.toFixed(4)}`;
+ // Preserve raw whitespace inside the span (CSS white-space: pre handles this).
+ span.textContent = ex.tokens[i] === '' ? ' ' : ex.tokens[i];
+ tokens.appendChild(span);
+ }
+ wrap.appendChild(tokens);
+
+ return wrap;
+}
+
+main().catch(err => {
+ document.getElementById('examples').innerHTML =
+ `Failed to load feature data: ${err.message}. ` +
+ `If you opened this with file://, serve the directory first ` +
+ `(e.g. python -m http.server).
`;
+});
diff --git a/crosslayer_transcoder/feature_dash/templates/index.html b/crosslayer_transcoder/feature_dash/templates/index.html
new file mode 100644
index 0000000..c00496f
--- /dev/null
+++ b/crosslayer_transcoder/feature_dash/templates/index.html
@@ -0,0 +1,37 @@
+
+
+
+
+ MoLT Feature Dashboard — index
+
+
+
+ MoLT feature dashboard
+ Loading…
+
+
+
+
+
+
+
+
+
+ | id |
+ tier |
+ rank |
+ rate |
+ max act |
+
+
+
+
+
+
+
diff --git a/crosslayer_transcoder/feature_dash/templates/index.js b/crosslayer_transcoder/feature_dash/templates/index.js
new file mode 100644
index 0000000..c2ee9bb
--- /dev/null
+++ b/crosslayer_transcoder/feature_dash/templates/index.js
@@ -0,0 +1,120 @@
+// Index page. One fetch of metadata.json populates the sortable table.
+
+let allRows = [];
+let sortKey = 'activation_rate';
+let sortDir = 'desc';
+let metadata = null;
+
+function pad(id, width) {
+ return String(id).padStart(width, '0');
+}
+
+async function main() {
+ metadata = await fetch('data/metadata.json').then(r => r.json());
+
+ document.getElementById('meta').textContent =
+ `${metadata.n_features} transforms · ` +
+ `${(metadata.hf_filename || metadata.ckpt_path)} · layer ${metadata.layer} · ` +
+ `${metadata.n_tokens_collected.toLocaleString()} tokens from ${metadata.dashboard_dataset}`;
+
+ // Build the row data.
+ for (let i = 0; i < metadata.n_features; i++) {
+ allRows.push({
+ feature_id: i,
+ tier: metadata.feature_tier[i],
+ rank: metadata.feature_rank[i],
+ activation_rate: metadata.feature_activation_rate[i],
+ max_activation: metadata.feature_max_activation[i],
+ });
+ }
+
+ // Populate tier filter.
+ const uniqueTiers = [...new Set(metadata.feature_tier)].sort((a, b) => a - b);
+ const tierSel = document.getElementById('tier-filter');
+ for (const t of uniqueTiers) {
+ const opt = document.createElement('option');
+ opt.value = String(t);
+ opt.textContent = `tier ${t} (rank ${metadata.ranks[t]})`;
+ tierSel.appendChild(opt);
+ }
+
+ // Wire interactions.
+ document.querySelectorAll('th[data-key]').forEach(th => {
+ th.addEventListener('click', () => {
+ const k = th.dataset.key;
+ if (sortKey === k) {
+ sortDir = sortDir === 'desc' ? 'asc' : 'desc';
+ } else {
+ sortKey = k;
+ sortDir = (k === 'feature_id' || k === 'tier') ? 'asc' : 'desc';
+ }
+ render();
+ });
+ });
+ ['change', 'input'].forEach(evt => {
+ tierSel.addEventListener(evt, render);
+ document.getElementById('min-rate').addEventListener(evt, render);
+ document.getElementById('max-rate').addEventListener(evt, render);
+ });
+
+ render();
+}
+
+function render() {
+ const tier = document.getElementById('tier-filter').value;
+ const minRate = parseFloat(document.getElementById('min-rate').value) || 0;
+ const maxRate = parseFloat(document.getElementById('max-rate').value);
+ const maxRateOk = Number.isFinite(maxRate) ? maxRate : 1;
+
+ const filtered = allRows.filter(r =>
+ (tier === '' || String(r.tier) === tier) &&
+ r.activation_rate >= minRate &&
+ r.activation_rate <= maxRateOk
+ );
+
+ filtered.sort((a, b) => {
+ const av = a[sortKey], bv = b[sortKey];
+ if (av < bv) return sortDir === 'asc' ? -1 : 1;
+ if (av > bv) return sortDir === 'asc' ? 1 : -1;
+ return a.feature_id - b.feature_id;
+ });
+
+ document.querySelectorAll('th[data-key]').forEach(th => {
+ th.classList.remove('sorted-asc', 'sorted-desc');
+ if (th.dataset.key === sortKey) {
+ th.classList.add(sortDir === 'asc' ? 'sorted-asc' : 'sorted-desc');
+ }
+ });
+
+ document.getElementById('count').textContent =
+ `${filtered.length.toLocaleString()} of ${allRows.length.toLocaleString()} shown`;
+
+ const tbody = document.querySelector('#features tbody');
+ // Render up to 2000 rows; the user can filter further if they need more.
+ const cap = 2000;
+ const rows = filtered.slice(0, cap);
+ const html = rows.map(r => {
+ const isDead = r.activation_rate === 0;
+ return `
+ | #${r.feature_id} |
+ ${r.tier} |
+ ${r.rank} |
+ ${(r.activation_rate * 100).toFixed(3)}% |
+ ${r.max_activation.toFixed(3)} |
+
`;
+ }).join('');
+ tbody.innerHTML = html;
+ if (filtered.length > cap) {
+ tbody.insertAdjacentHTML(
+ 'beforeend',
+ `| ${filtered.length - cap} more rows hidden — narrow the filter |
`
+ );
+ }
+}
+
+main().catch(err => {
+ document.getElementById('meta').innerHTML =
+ `Failed to load metadata.json: ${err.message}. ` +
+ `If you opened this with file://, serve the directory first ` +
+ `(e.g. python -m http.server).`;
+});
diff --git a/crosslayer_transcoder/model/__init__.py b/crosslayer_transcoder/model/__init__.py
index e4f0aae..6a88618 100644
--- a/crosslayer_transcoder/model/__init__.py
+++ b/crosslayer_transcoder/model/__init__.py
@@ -4,11 +4,13 @@
from .clt import CrossLayerTranscoder
from .clt_lightning import CrossLayerTranscoderModule
+from .molt import Molt
from .topk import BatchTopK, PerLayerBatchTopK, PerLayerTopK
__all__ = [
"CrossLayerTranscoder",
"CrossLayerTranscoderModule",
+ "Molt",
"BatchTopK",
"PerLayerTopK",
"PerLayerBatchTopK",
diff --git a/crosslayer_transcoder/model/clt_lightning.py b/crosslayer_transcoder/model/clt_lightning.py
index 8e4a020..2256028 100644
--- a/crosslayer_transcoder/model/clt_lightning.py
+++ b/crosslayer_transcoder/model/clt_lightning.py
@@ -2,7 +2,7 @@
import os
import subprocess
import time
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
import lightning as L
import psutil
@@ -23,6 +23,7 @@
Decoder,
)
from crosslayer_transcoder.model.jumprelu import JumpReLU
+from crosslayer_transcoder.model.molt import Molt
from crosslayer_transcoder.model.topk import BatchTopK
@@ -30,7 +31,7 @@ class CrossLayerTranscoderModule(L.LightningModule):
def __init__(
self,
# Pre-constructed modules
- model: CrossLayerTranscoder,
+ model: Union[CrossLayerTranscoder, Molt],
replacement_model: Optional[ReplacementModelAccuracy] = None,
dead_features: Optional[DeadFeatures] = None,
# Training parameters
@@ -85,17 +86,23 @@ def __init__(
self.beta2 = beta2
self.log_metrics_every = log_metrics_every
- assert self.model.encoder.n_layers == self.model.decoder.n_layers, (
- "Encoder and decoder must have the same number of layers"
- )
+ if isinstance(self.model, Molt):
+ self.register_buffer(
+ "last_active",
+ torch.zeros((self.model.n_features,), dtype=torch.long),
+ )
+ else:
+ assert self.model.encoder.n_layers == self.model.decoder.n_layers, (
+ "Encoder and decoder must have the same number of layers"
+ )
- self.register_buffer(
- "last_active",
- torch.zeros(
- (self.model.encoder.n_layers, self.model.encoder.d_features),
- dtype=torch.long,
- ),
- )
+ self.register_buffer(
+ "last_active",
+ torch.zeros(
+ (self.model.encoder.n_layers, self.model.encoder.d_features),
+ dtype=torch.long,
+ ),
+ )
def configure_model(self):
# Apply compilation if requested
@@ -565,3 +572,68 @@ def training_step(self, batch, batch_idx):
torch.cuda.memory._record_memory_history(enabled=None)
exit()
return loss
+
+
+class MoltModule(CrossLayerTranscoderModule):
+ def __init__(
+ self,
+ lambda_sparsity: float = 0.0002,
+ c_sparsity: float = 0.1,
+ use_tanh: bool = True,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self._lambda = lambda_sparsity
+ self.c = c_sparsity
+ self.use_tanh = use_tanh
+
+ def current_sparsity_penalty(self):
+ n_steps = self.trainer.max_steps
+ current_step = (
+ self.global_step
+ ) # use global step instead of batch idx to work with gradient accumulation
+ cur_lambda = self._lambda * (current_step / n_steps)
+ self.log("training/sparsity_penalty", cur_lambda)
+ return cur_lambda
+
+ def forward(self, batch, layer):
+ return self.model.forward(batch, layer)
+
+ def training_step(self, batch, batch_idx):
+ if batch_idx == 0:
+ self.model.initialize_standardizers(batch)
+ self.log("model/d_latents", self.model.d_latents)
+ self.log("model/n_features", self.model.n_features)
+
+ layer = 8
+
+ # Forward pass
+ resid, mlp_out = batch[:, 0], batch[:, 1]
+ resid = resid[:, layer]
+ mlp_out = mlp_out[:, layer]
+ gate, recons_norm, recons = self.model.forward(resid, layer)
+
+ self.update_dead_features(gate)
+ # Compute MSE loss
+ mse = (recons_norm - self.model.output_standardizer.standardize(mlp_out, layer)) ** 2
+
+ # Compute Sparsity Loss
+ norms = self.model.transform_norm()
+ weighted_norms = norms * gate
+ self.log("model/weighted_norms_mean", weighted_norms.detach().mean().cpu())
+
+ if self.use_tanh:
+ weighted_norms = torch.tanh(weighted_norms * self.c)
+ sparsity = self.current_sparsity_penalty() * weighted_norms.sum(dim=-1).mean()
+ self.log("training/sparsity_loss", sparsity)
+ self.log("L0", (gate > 0.0).float().sum() / gate.shape[0])
+
+ loss = mse.mean() + sparsity
+ self.log("training/mse", mse.mean())
+ self.log("training/loss", loss)
+
+ if batch_idx % self.log_metrics_every == 0:
+ pass
+
+ return loss
diff --git a/crosslayer_transcoder/model/jumprelu.py b/crosslayer_transcoder/model/jumprelu.py
index d21b7ec..c977830 100644
--- a/crosslayer_transcoder/model/jumprelu.py
+++ b/crosslayer_transcoder/model/jumprelu.py
@@ -52,7 +52,8 @@ def backward(ctx, grad_output):
class JumpReLU(SerializableModule):
def __init__(self, theta=0.0, bandwidth=1.0, n_layers=12, d_features=768 * 8):
super().__init__()
- self.theta = nn.Parameter(torch.full((1, n_layers, d_features), theta))
+ shape = (1, n_layers, d_features) if n_layers > 1 else (1, d_features)
+ self.theta = nn.Parameter(torch.full(shape, theta))
self.register_buffer("bandwidth", torch.tensor(bandwidth))
self._init_theta = theta
self.n_layers = n_layers
diff --git a/crosslayer_transcoder/model/molt.py b/crosslayer_transcoder/model/molt.py
new file mode 100644
index 0000000..afb3fd3
--- /dev/null
+++ b/crosslayer_transcoder/model/molt.py
@@ -0,0 +1,93 @@
+import einops
+import torch
+import torch.nn as nn
+from jaxtyping import Float
+
+
+class Molt(nn.Module):
+ def __init__(
+ self,
+ d_acts: int,
+ N: int,
+ nonlinearity: nn.Module,
+ input_standardizer: nn.Module,
+ output_standardizer: nn.Module,
+ ranks: list[int] = [512, 256, 128, 64, 32],
+ ):
+ super().__init__()
+
+ self.d_acts = d_acts
+ self.nonlinearity = nonlinearity
+ self.input_standardizer = input_standardizer
+ self.output_standardizer = output_standardizer
+ Us = []
+ Vs = []
+ rank_multiplier = 1
+ n_features = 0
+ d_latents = 0
+ for rank in ranks:
+ Us.append(nn.Parameter(torch.empty(N * rank_multiplier, rank, d_acts)))
+ Vs.append(nn.Parameter(torch.empty(N * rank_multiplier, d_acts, rank)))
+ n_features += N * rank_multiplier
+ d_latents += N * rank_multiplier * rank
+ rank_multiplier *= 2
+ self.n_features = n_features
+ self.e = nn.Linear(d_acts, n_features)
+ self.Us = nn.ParameterList(Us)
+ self.Vs = nn.ParameterList(Vs)
+
+ print(f"d_latents (transcoder equivalent): {d_latents}")
+ self.d_latents = d_latents
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for U in self.Us:
+ nn.init.xavier_uniform_(U)
+ for V in self.Vs:
+ nn.init.xavier_uniform_(V)
+
+ def transform_norm(self):
+ norms = []
+ for U, V in zip(self.Us, self.Vs):
+ uv = einops.einsum(
+ U,
+ V,
+ "n_transforms d_transform d_acts_out, n_transforms d_acts_in d_transform -> n_transforms d_acts_in d_acts_out",
+ )
+ norms.append(torch.norm(uv, dim=(1, 2)))
+ return torch.cat(norms, dim=0)
+
+ def forward(
+ self, acts: Float[torch.Tensor, "batch_size d_acts"], layer: int
+ ) -> Float[torch.Tensor, "batch_size d_acts"]:
+ acts = self.input_standardizer(acts, layer)
+ pre_actvs = self.e(acts)
+ gate = self.nonlinearity(pre_actvs) # (batch, n_transforms)
+
+ raw_recons = []
+ for U, V in zip(self.Us, self.Vs):
+ latents = einops.einsum(
+ acts,
+ V,
+ "batch d_acts, n_transforms d_acts d_transform -> batch n_transforms d_transform",
+ )
+ raw_recons.append(
+ einops.einsum(
+ latents,
+ U,
+ "batch n_transforms d_transform, n_transforms d_transform d_acts -> batch n_transforms d_acts",
+ )
+ )
+
+ raw_recons = torch.cat(raw_recons, dim=1)
+
+ weighted_recons = gate.unsqueeze(-1) * raw_recons
+ recons_norm = weighted_recons.sum(dim=1)
+
+ recons = self.output_standardizer(recons_norm, layer)
+ return gate, recons_norm, recons
+
+ def initialize_standardizers(self, batch: Float[torch.Tensor, "batch_size io n_layers d_acts"]):
+ self.input_standardizer.initialize_from_batch(batch)
+ self.output_standardizer.initialize_from_batch(batch)
diff --git a/pyproject.toml b/pyproject.toml
index 849e873..52e8de7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,6 +2,9 @@
requires = ["hatchling"]
build-backend = "hatchling.build"
+[tool.hatch.build.targets.wheel.force-include]
+"crosslayer_transcoder/feature_dash/templates" = "crosslayer_transcoder/feature_dash/templates"
+
[project]
name = "crosslayer-transcoder"
version = "0.1.0"
diff --git a/tests/test_feature_dash_bundle.py b/tests/test_feature_dash_bundle.py
new file mode 100644
index 0000000..d680bee
--- /dev/null
+++ b/tests/test_feature_dash_bundle.py
@@ -0,0 +1,250 @@
+"""Tests for the single-file `bundle.html` builder.
+
+Two paths to the bundle:
+ - `make_bundle_from_disk(out_dir)`: read an existing dump, write bundle.html.
+ - `make_bundle(collector, meta, tokenizer, out_path, ...)`: build directly
+ from in-memory state, no disk dump needed.
+
+Both should produce a self-contained file (no `fetch(` calls, all data inlined)
+that contains every feature's payload.
+"""
+
+from __future__ import annotations
+
+import json
+import re
+from pathlib import Path
+
+import pytest
+import torch
+
+from crosslayer_transcoder.feature_dash.bundle import (
+ default_bundle_filename,
+ make_bundle,
+ make_bundle_from_disk,
+)
+from crosslayer_transcoder.feature_dash.collect import GateCollector
+from crosslayer_transcoder.feature_dash.dump import dump_dashboard
+from crosslayer_transcoder.feature_dash.load import MoltCheckpointMetadata
+
+
+class _StubTokenizer:
+ def decode(self, ids):
+ return f" tok{ids[0]}"
+
+
+def _populate_collector(F: int, K: int, T: int) -> GateCollector:
+ coll = GateCollector(n_features=F, top_k=K, seq_len=T)
+ tok = torch.arange(2 * T, dtype=torch.long).reshape(2, T)
+ g = torch.zeros(2, T, F)
+ g[0, 1, 0] = 1.0
+ g[1, 3, 0] = 2.5
+ g[0, 0, 1] = 0.5
+ coll.update(tok, g)
+ return coll
+
+
+def _meta(F: int, ranks: list[int], N: int) -> MoltCheckpointMetadata:
+ feature_tier: list[int] = []
+ feature_rank: list[int] = []
+ for t, r in enumerate(ranks):
+ n_in_tier = N * (2**t)
+ feature_tier.extend([t] * n_in_tier)
+ feature_rank.extend([r] * n_in_tier)
+ return MoltCheckpointMetadata(
+ ckpt_path="dummy.ckpt",
+ d_acts=8,
+ n_features=F,
+ n_layers=2,
+ ranks=ranks,
+ N=N,
+ feature_tier=feature_tier,
+ feature_rank=feature_rank,
+ base_model_name="stub",
+ training_dataset="stub-ds",
+ global_step=123,
+ epoch=0,
+ )
+
+
+def _extract_inlined_json(html: str, script_id: str) -> dict:
+ """Pull a `` payload."""
+ m = re.search(
+ rf'',
+ html,
+ flags=re.DOTALL,
+ )
+ assert m, f"missing inlined script #{script_id}"
+ raw = m.group(1)
+ # Reverse the `<` hardening before parsing.
+ raw = raw.replace(r"<", "<")
+ return json.loads(raw)
+
+
+def test_bundle_contains_no_fetch_calls(tmp_path: Path):
+ """Bundle must work from file:// — no network/disk fetches allowed."""
+ F, K, T, N = 3, 2, 5, 1
+ coll = _populate_collector(F, K, T)
+ meta = _meta(F, [4, 2], N)
+
+ out = make_bundle(
+ collector=coll,
+ meta=meta,
+ tokenizer=_StubTokenizer(),
+ out_path=tmp_path / "bundle.html",
+ layer=8,
+ window=2,
+ )
+ text = out.read_text()
+ # No fetch(... call anywhere — neither real nor templated.
+ assert "fetch(" not in text
+
+
+def test_bundle_inlines_metadata_and_all_features(tmp_path: Path):
+ F, K, T, N = 3, 2, 5, 1
+ coll = _populate_collector(F, K, T)
+ meta = _meta(F, [4, 2], N)
+
+ bundle_path = make_bundle(
+ collector=coll, meta=meta, tokenizer=_StubTokenizer(),
+ out_path=tmp_path, layer=8, window=2,
+ )
+ # When out_path is a directory, the bundle is named after the checkpoint.
+ # `_meta()` sets ckpt_path="dummy.ckpt" and no hf_filename, so stem="dummy".
+ assert bundle_path == tmp_path / "bundle_dummy.html"
+
+ html = bundle_path.read_text()
+ md = _extract_inlined_json(html, "metadata")
+ feats = _extract_inlined_json(html, "features")
+
+ assert md["n_features"] == F
+ assert md["layer"] == 8
+ assert md["feature_tier"] == [0, 1, 1]
+ assert md["feature_rank"] == [4, 2, 2]
+ assert len(md["feature_activation_rate"]) == F
+ assert len(md["feature_max_activation"]) == F
+
+ assert set(feats.keys()) == {"0", "1", "2"}
+ for k, v in feats.items():
+ assert v["feature_id"] == int(k)
+ assert {"tier", "rank", "activation_rate", "max_activation", "examples"} <= set(v.keys())
+
+ # Feature 0 has two real examples and feature 2 is dead.
+ assert len(feats["0"]["examples"]) == 2
+ assert feats["2"]["examples"] == []
+
+
+def test_bundle_css_is_inlined(tmp_path: Path):
+ F, K, T, N = 3, 2, 5, 1
+ coll = _populate_collector(F, K, T)
+ meta = _meta(F, [4, 2], N)
+ out = make_bundle(
+ collector=coll, meta=meta, tokenizer=_StubTokenizer(),
+ out_path=tmp_path / "bundle.html", layer=8, window=2,
+ )
+ html = out.read_text()
+ # No external stylesheet link, and a non-empty inline " in html
+ assert "table.features" in html # one of our CSS rules
+
+
+def test_bundle_from_disk_reads_dumped_dashboard(tmp_path: Path):
+ F, K, T, N = 3, 2, 5, 1
+ coll = _populate_collector(F, K, T)
+ meta = _meta(F, [4, 2], N)
+
+ dump_dashboard(
+ collector=coll, meta=meta, tokenizer=_StubTokenizer(),
+ out_dir=tmp_path, layer=8, window=2, copy_assets=False,
+ )
+ bundle_path = make_bundle_from_disk(tmp_path)
+
+ # ckpt_path="dummy.ckpt" -> bundle_dummy.html
+ assert bundle_path == tmp_path / "bundle_dummy.html"
+ html = bundle_path.read_text()
+ md = _extract_inlined_json(html, "metadata")
+ feats = _extract_inlined_json(html, "features")
+ assert md["n_features"] == F
+ assert len(feats) == F
+
+
+def test_bundle_from_disk_errors_when_data_missing(tmp_path: Path):
+ with pytest.raises(FileNotFoundError):
+ make_bundle_from_disk(tmp_path)
+
+
+def test_default_bundle_filename_prefers_hf_filename():
+ F, N = 3, 1
+ meta = _meta(F, [4, 2], N)
+ meta.hf_filename = "gpt2-molt-lam-0_00015-50M.ckpt"
+ meta.ckpt_path = "/some/random/local/path.ckpt"
+ assert default_bundle_filename(meta) == "bundle_gpt2-molt-lam-0_00015-50M.html"
+
+
+def test_default_bundle_filename_falls_back_to_ckpt_path():
+ F, N = 3, 1
+ meta = _meta(F, [4, 2], N)
+ meta.hf_filename = None
+ meta.ckpt_path = "/runs/exp42/checkpoint-final.ckpt"
+ assert default_bundle_filename(meta) == "bundle_checkpoint-final.html"
+
+
+def test_default_bundle_filename_handles_unsafe_chars():
+ F, N = 3, 1
+ meta = _meta(F, [4, 2], N)
+ meta.hf_filename = "weird/name with spaces.ckpt"
+ # The stem is "name with spaces" (Path.stem already drops the directory).
+ # Spaces are sanitised to '-'.
+ assert default_bundle_filename(meta) == "bundle_name-with-spaces.html"
+
+
+def test_make_bundle_explicit_filepath_is_respected(tmp_path: Path):
+ F, K, T, N = 3, 2, 5, 1
+ coll = _populate_collector(F, K, T)
+ meta = _meta(F, [4, 2], N)
+
+ explicit = tmp_path / "my-custom-name.html"
+ out = make_bundle(
+ collector=coll, meta=meta, tokenizer=_StubTokenizer(),
+ out_path=explicit, layer=8, window=2,
+ )
+ # When the user passes a file path, we don't override it.
+ assert out == explicit
+
+
+def test_bundle_handles_script_tag_in_token_text(tmp_path: Path):
+ """A token string containing `` must not break the bundle.
+
+ The sanitiser replaces `<` with `\\u003c`; the JS in the bundle parses the
+ JSON via `JSON.parse`, which decodes the unicode escape back. We just
+ verify here that the raw HTML doesn't contain a literal `` inside
+ the data payload.
+ """
+
+ class _NastyTokenizer:
+ def decode(self, ids):
+ return "BAD"
+
+ F, K, T, N = 3, 2, 5, 1
+ coll = _populate_collector(F, K, T)
+ meta = _meta(F, [4, 2], N)
+
+ out = make_bundle(
+ collector=coll, meta=meta, tokenizer=_NastyTokenizer(),
+ out_path=tmp_path / "bundle.html", layer=8, window=2,
+ )
+ html = out.read_text()
+ # Find the data scripts and check their content is sanitised.
+ m = re.search(
+ r'',
+ html, flags=re.DOTALL,
+ )
+ assert m
+ payload = m.group(1)
+ # The sanitised payload must not contain a literal "" sequence.
+ assert "" not in payload, payload
+ # And the bundle as a whole still has exactly the two data scripts and
+ # one logic script — no premature closure.
+ assert html.count('