Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
e1a833c
tested converstion to "forward" decoder
jiito Sep 25, 2025
8774ddd
successful integration test
jiito Sep 26, 2025
cee0d02
rm imports
jiito Sep 30, 2025
211918f
add CT callback
jiito Sep 30, 2025
cd8ebda
initial working modal training setup
jiito Sep 13, 2025
e682bd2
trust remote code
jiito Sep 16, 2025
67eb715
add HF_HOME env vars
jiito Sep 16, 2025
0234e73
working modal code w/o cleanup
jiito Sep 17, 2025
6125bbc
reinit logs incase of modal preemption
jiito Sep 17, 2025
e07e063
replacment model config
jiito Sep 17, 2025
8a51670
basic script to test mem usage on modal
jiito Sep 19, 2025
36d044f
refactor: greatly simplify modal wrapper
jiito Sep 30, 2025
6d069a2
restore formatting issues
jiito Sep 30, 2025
f6bd08e
remove unused import
jiito Oct 1, 2025
7c0a36a
uv sync
jiito Oct 1, 2025
6a1f215
add example config with cb, save model intermediately
jiito Oct 1, 2025
d0e70be
add TODO
jiito Oct 1, 2025
6f1f45b
add non-linearity
jiito Oct 1, 2025
e5844e7
refactor: move to test and util file
jiito Oct 1, 2025
7b383c3
add upload to hf callback
jiito Oct 2, 2025
2617e58
debug config
jiito Oct 13, 2025
f03aab0
refactor: change to abc converters
jiito Oct 13, 2025
af7fa38
update test
jiito Oct 13, 2025
a0feda0
feat: first pass at standardization folding
jiito Oct 13, 2025
12ab5c8
debug+circuit-tracer config
jiito Oct 14, 2025
8fc6284
fix: shapes of standardization
jiito Oct 14, 2025
c64955f
update config
jiito Oct 15, 2025
890a60c
initial failing test case with float
jiito Oct 17, 2025
459acd7
fix: numerical stability
jiito Oct 17, 2025
74a80ea
add const, rm comments
jiito Oct 17, 2025
3a76fb7
refactor: use standardizer method in conversion
jiito Oct 17, 2025
c19536b
cleanup comments
jiito Oct 17, 2025
c97a240
rm ckpt loading test
jiito Nov 4, 2025
663f5d2
remove debug configs
jiito Nov 6, 2025
9d8ea2b
remove main.py changes
jiito Nov 6, 2025
a3e6045
remove modal changes
jiito Nov 6, 2025
06c9f4c
restore lightning file
jiito Nov 6, 2025
9158094
make conversion callback more general
jiito Nov 6, 2025
26c8d4f
add circuit-tracer debug config
jiito Nov 6, 2025
cfb1631
cleanup unused methods
jiito Nov 6, 2025
9d1594a
fix: circuit tracer installation
jiito Nov 6, 2025
1ee4eb4
downgrade datasets
jiito Nov 6, 2025
61801e5
fix callback params
jiito Nov 6, 2025
1d3ffe1
Merge branch 'master' of github.com:Goreg12345/crosslayer-transcoder …
jiito Nov 8, 2025
0e6d1c5
rm math sanity check test
jiito Nov 8, 2025
ba74a53
change to train batch end
jiito Nov 8, 2025
cbd9d3a
cleanup + todos
jiito Nov 8, 2025
9bf31ce
basic checkpoint loading + test
jiito Nov 8, 2025
476f9c6
debug config file
jiito Nov 8, 2025
8c1fcb4
revert formatting changes in clt_lightning
jiito Nov 8, 2025
ce034d7
revert comment
jiito Nov 8, 2025
fb768e9
rename to clearer method
jiito Nov 8, 2025
d57b766
add file checks
jiito Nov 8, 2025
c347ba2
save each decoder
jiito Nov 8, 2025
11d92ee
rm unused import
jiito Nov 8, 2025
705c4e8
move fixture out
jiito Nov 8, 2025
72d1e1b
rm prints
jiito Nov 8, 2025
bfce3bb
fix: non-lin, rm hf upload
jiito Nov 8, 2025
e13e04d
add basic test for dowloading clt ckpt
jiito Nov 13, 2025
4271019
basic checkpoint loading + test
jiito Nov 13, 2025
62e2fd9
add topk debug config
jiito Nov 13, 2025
d94a681
init replacement score notebook
jiito Nov 13, 2025
2be6460
load from config
jiito Nov 13, 2025
42e3ebf
init sanity check
jiito Nov 15, 2025
4631d61
add sanity checks to replacement score
jiito Nov 16, 2025
2f4a19e
update debug config
jiito Nov 16, 2025
e00d31b
sanity check
jiito Nov 16, 2025
e51ba77
add encode and encode folded
jiito Nov 16, 2025
5d99ebf
add dtype
jiito Nov 16, 2025
26889e0
cleanup tests
jiito Nov 16, 2025
b5c2331
add progress bar
jiito Nov 16, 2025
5d6ac4a
run notebook sanity checks
jiito Nov 16, 2025
a872e3f
fix: load from pretrained and update deps
jiito Nov 17, 2025
fec1d0e
save titles
jiito Nov 17, 2025
d494578
restore util file
jiito Nov 17, 2025
481d4e8
Merge branch 'conversion-to-circuit-tracer' of github.com:jiito/cross…
jiito Nov 17, 2025
0b2a114
restore formatting
jiito Nov 17, 2025
6a03202
cleanup tests and unused methods
jiito Nov 18, 2025
8652e3a
test: config of hook points for callback
jiito Nov 18, 2025
718586a
use partial
jiito Nov 18, 2025
81298bb
refactor: rename and delete indirection
jiito Nov 18, 2025
0c66441
refactor: change to protocol over ABC
jiito Nov 19, 2025
3550185
add safety to module builder
jiito Nov 19, 2025
a28f4a1
fix: circuit tracer converter callback
jiito Nov 22, 2025
2746ea8
fix: remove direct typehints and add comments
jiito Dec 1, 2025
0100c1d
borken for comparison
jiito Dec 3, 2025
1b4e8d6
fix: protocol typing
jiito Dec 4, 2025
c87117a
add runtime_checkable to protocol
jiito Dec 8, 2025
30c067a
Merge branch 'master' of github.com:Goreg12345/crosslayer-transcoder …
jiito Jan 15, 2026
c9b0fc6
fix: move tests
jiito Jan 15, 2026
40d1291
add relu error test
jiito Jan 15, 2026
3e57bb1
fix: uv sync
jiito Jan 15, 2026
9349b54
fix: workflow following uv docs
jiito Jan 15, 2026
68165fe
fix: stopgap ignore dataset
jiito Jan 15, 2026
36e6ba0
Merge branch 'master' of github.com:Goreg12345/crosslayer-transcoder …
jiito Jan 27, 2026
ab7a0ec
small formatting fixes
jiito Jan 28, 2026
cc6d204
Merge branch 'fix/pr-tests-workflow' of github.com:jiito/crosslayer-t…
jiito Jan 28, 2026
766e9ca
wip: refactor from_pretratined to take hf_urls
jiito Feb 5, 2026
625f5b2
refactor to avoid lightning module
jiito Feb 5, 2026
be14218
feat: support PerLayerTopK activation function
jiito Feb 5, 2026
f7d2308
fix: tests
jiito Feb 5, 2026
1893c11
use IOI prompt
jiito Feb 5, 2026
542149c
cleanup test
jiito Feb 5, 2026
b3a3af0
initial refactor of model saving
jiito Feb 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/pr-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ jobs:
python-version: "3.11"

- name: Install uv
run: pip install uv
uses: astral-sh/setup-uv@v7

- name: Install dependencies
run: uv install
- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Run pytest
run: uv run pytest
- name: Run tests
run: uv run pytest --ignore=tests/test_text_dataset.py
172 changes: 172 additions & 0 deletions config/circuit-tracer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Default configuration for CrossLayer Transcoder training
# This file uses Lightning CLI's automatic class construction

seed_everything: 42

trainer:
# max_steps is number of gradient updates. If using gradient accumulation, this is not the number of batches.
max_steps: 100_000
val_check_interval: 1_000
limit_val_batches: 1
enable_checkpointing: false # We use custom end-of-training checkpoint
num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized
precision: "16-mixed"
accelerator: "gpu"
devices: [0] # [0] means cuda:0
accumulate_grad_batches: 1
logger: # WandB logger is recommended but other loggers are supported as well
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: "clt"
name: "circuit-tracer"
save_dir: "./wandb"
callbacks:
- class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback
init_args:
checkpoint_dir: "checkpoints"
- class_path: crosslayer_transcoder.utils.callbacks.ModelConversionCallback
init_args:
converter:
class_path: crosslayer_transcoder.utils.model_converters.circuit_tracer.CircuitTracerConverter
init_args:
save_dir: "circuit-tracer"
feature_input_hook: "hook_resid_mid"
feature_output_hook: "hook_mlp_out"
on_events: ["on_train_batch_end"]

model:
class_path: crosslayer_transcoder.model.clt_lightning.JumpReLUCrossLayerTranscoderModule
init_args:
model:
class_path: crosslayer_transcoder.model.clt.CrossLayerTranscoder
init_args:
encoder:
class_path: crosslayer_transcoder.model.clt.Encoder
init_args:
d_acts: 768
d_features: 10_000
n_layers: 12

decoder:
class_path: crosslayer_transcoder.model.clt.CrosslayerDecoder
init_args:
d_acts: 768
d_features: 10_000
n_layers: 12

nonlinearity:
class_path: crosslayer_transcoder.model.jumprelu.JumpReLU
init_args:
theta: 0.03
bandwidth: 0.01
n_layers: 12
d_features: 10_000

input_standardizer:
class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer
init_args:
n_layers: 12
activation_dim: 768

output_standardizer:
class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer
init_args:
n_layers: 12
activation_dim: 768

# Pre-constructed replacement model
replacement_model:
class_path: crosslayer_transcoder.metrics.replacement_model_accuracy.ReplacementModelAccuracy
init_args:
model_name: "openai-community/gpt2"
device_map: "cuda:0" # should match trainer.devices
loader_batch_size: 2

# Pre-constructed dead features metric
dead_features:
class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures
init_args:
n_features: 10_000
n_layers: 12
return_per_layer: true
return_log_freqs: true
return_neuron_indices: true


# Training parameters
learning_rate: 3e-4
compile: true # if using torch.compile
lr_decay_step: 16_000 # lr is scaled by lr_decay_factor after this many steps
lr_decay_factor: 0.1

lambda_sparsity: 0.0007 # sparsity loss weight
c_sparsity: 1 # sparsity loss coefficient
use_tanh: true # use tanh nonlinearity in the JumpReLU
pre_actv_loss: 1e-6 # pre-activation loss weight

# Dead features computation settings
compute_dead_features: true
compute_dead_features_every: 500

data:
class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule
init_args:
# Buffer settings
buffer_size: 500_000 # number of activations to store in the buffer
n_in_out: 2 # number of input and output layers
n_layers: 12 # number of layers in the model
activation_dim: 768 # dimension of the activations
dtype: "float16" # dtype of the activations
max_batch_size: 8000 # maximum batch size for the data loader

# Model settings for activation generation
model_name: "openai-community/gpt2"
model_dtype: "float32"

# Dataset settings
dataset_name: "Skylion007/openwebtext"
dataset_split: "train"
max_sequence_length: 1024

# Generation settings
generation_batch_size: 10
refresh_interval: 0.1 # time (s) between shell logs updates

# Memory settings
shared_memory_name: "activation_buffer"
timeout_seconds: 30

# File paths
init_file: null # path to file with shuffled activations to initialize the buffer fast
# if null, activations are generated and training starts when the buffer is at least minimum_fill_threshold full

# DataLoader settings
batch_size: 1000
num_workers: 10
prefetch_factor: 2
shuffle: true
persistent_workers: true
pin_memory: true

minimum_fill_threshold: 0.2 # Only provide activations when buffer is at least 20% full
# to maintain sufficient shuffling

use_shared_memory: true

# Device configuration
device_map: "cuda:0" # "cpu", "auto", "cuda:0", "cuda:0,1,2,3"
deployment_policy: "gpu_only" # "cpu_only", "gpu_only", or "dynamic"
# dynamic will use CPU and only GPU if the buffer is almost empty to refill fast. Use this if you use a single GPU and have a beefy CPU.

# WandB logging configuration for data generation
wandb_logging:
enabled: true # Enable WandB logging for data generation
project: "clt" # WandB project (should match trainer logger)
group: null # Group name (null = auto-generated from training run)
run_name: "data-generator-jumprelu" # Run name suffix
tags: ["data-generation"] # Tags for the data generation run
save_dir: "./wandb" # Directory for WandB files
log_interval: 5.0 # Logging interval in seconds
offline: true # Offline mode for WandB logging

ckpt_path: null
81 changes: 71 additions & 10 deletions crosslayer_transcoder/model/clt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import einops
import torch
import torch.nn as nn
import yaml
Expand Down Expand Up @@ -194,6 +195,17 @@ def to_config(self) -> Dict[str, Any]:
},
}

def to_circuit_tracer(self):
W = einops.rearrange(
self.get_parameter("W"),
"n_layers d_acts d_features -> n_layers d_features d_acts",
).contiguous()
b = self.get_parameter("b")
return {
"W": W,
"b": b,
}


class Decoder(SerializableModule):
def __init__(self, d_acts: int, d_features: int, n_layers: int):
Expand All @@ -202,16 +214,16 @@ def __init__(self, d_acts: int, d_features: int, n_layers: int):
self.d_features = d_features
self.n_layers = n_layers
self.register_parameter(
f"W", nn.Parameter(torch.empty((n_layers, d_features, d_acts)))
"W", nn.Parameter(torch.empty((n_layers, d_features, d_acts)))
)
self.register_parameter(f"b", nn.Parameter(torch.empty((n_layers, d_acts))))
self.register_parameter("b", nn.Parameter(torch.empty((n_layers, d_acts))))
self._is_folded = False
self.reset_parameters()

def reset_parameters(self):
dec_uniform_thresh = 1 / ((self.d_acts * self.n_layers) ** 0.5)
self.get_parameter(f"W").data.uniform_(-dec_uniform_thresh, dec_uniform_thresh)
self.get_parameter(f"b").data.zero_()
self.get_parameter("W").data.uniform_(-dec_uniform_thresh, dec_uniform_thresh)
self.get_parameter("b").data.zero_()

@torch.no_grad()
def forward_layer(
Expand All @@ -224,7 +236,7 @@ def forward_layer(
return (
einsum(
features,
self.get_parameter(f"W")[layer],
self.get_parameter("W")[layer],
"batch_size seq d_features, d_features d_acts -> batch_size seq d_acts",
)
+ self.b[layer]
Expand Down Expand Up @@ -275,6 +287,12 @@ def to_config(self) -> Dict[str, Any]:
},
}

def to_circuit_tracer(self):
return {
"W": self.W,
"b": self.b,
}


class CrosslayerDecoder(SerializableModule):
def __init__(self, d_acts: int, d_features: int, n_layers: int):
Expand All @@ -287,7 +305,7 @@ def __init__(self, d_acts: int, d_features: int, n_layers: int):
f"W_{i}", nn.Parameter(torch.empty((i + 1, d_features, d_acts)))
)
self._is_folded = False
self.register_parameter(f"b", nn.Parameter(torch.empty((n_layers, d_acts))))
self.register_parameter("b", nn.Parameter(torch.empty((n_layers, d_acts))))
self.reset_parameters()

def reset_parameters(self):
Expand Down Expand Up @@ -331,15 +349,15 @@ def forward(
device=features.device,
dtype=features.dtype,
)
for l in range(self.n_layers):
W = self.get_parameter(f"W_{l}")
selected_features = features[:, : l + 1]
for layer_idx in range(self.n_layers):
W = self.get_parameter(f"W_{layer_idx}")
selected_features = features[:, : layer_idx + 1]
l_recons = einsum(
selected_features,
W,
"batch_size n_layers d_features, n_layers d_features d_acts -> batch_size d_acts",
)
recons[:, l, :] = l_recons
recons[:, layer_idx, :] = l_recons
recons = recons + self.b.to(features.dtype)
return recons

Expand Down Expand Up @@ -369,6 +387,32 @@ def to_config(self) -> Dict[str, Any]:
},
}

def to_circuit_tracer(self):
output_decs = []
for source_layer in range(self.n_layers):
output_dec_i = torch.zeros(
[self.d_features, self.n_layers - source_layer, self.d_acts],
)

for k in range(source_layer, self.n_layers):
# get decoder mat for layer i --> k
decoder_w_k = self.get_parameter(f"W_{k}")

dec_i_k = decoder_w_k[source_layer, ...]
assert dec_i_k.shape == (
self.d_features,
self.d_acts,
)

output_dec_i[:, k - source_layer, ...] = dec_i_k

output_decs.append(output_dec_i)

return {
"W": output_decs,
"b": self.b,
}


class CrossLayerTranscoder(SerializableModule):
def __init__(
Expand Down Expand Up @@ -465,3 +509,20 @@ def save_pretrained(self, directory: Path, fold_standardizers: bool = True):
yaml.dump({"model": config}, f)

save_file(self.state_dict(), directory / "checkpoint.safetensors")

def to_circuit_tracer(self):
# NOTE: this mutates the model in-place. Potentially bad, but a tradeoff for copying a huge model.
self.fold()

encoder = self.encoder.to_circuit_tracer()
decoder = self.decoder.to_circuit_tracer()

is_per_layer_decoder = isinstance(self.decoder, Decoder)

config = {
"is_per_layer_decoder": is_per_layer_decoder,
"encoder": encoder,
"decoder": decoder,
}

return config
12 changes: 9 additions & 3 deletions crosslayer_transcoder/model/jumprelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def backward(ctx, grad_output):
grad_input = grad_output.clone()
grad_input[input < 0] = 0

theta_grad = -(theta / bandwidth) * rectangle((input - theta) / bandwidth) * grad_output
theta_grad = (
-(theta / bandwidth) * rectangle((input - theta) / bandwidth) * grad_output
)
return grad_input, theta_grad, None


Expand Down Expand Up @@ -78,7 +80,9 @@ class HeavysideStep(torch.autograd.Function):
def forward(ctx, input, theta, bandwidth):
ctx.save_for_backward(input, theta)
ctx.bandwidth = bandwidth
return torch.where(input - theta > 0, torch.ones_like(input), torch.zeros_like(input))
return torch.where(
input - theta > 0, torch.ones_like(input), torch.zeros_like(input)
)

@staticmethod
def backward(ctx, grad_output):
Expand All @@ -87,5 +91,7 @@ def backward(ctx, grad_output):
grad_input = grad_output.clone()
grad_input = grad_output * 0.0

theta_grad = -(1.0 / bandwidth) * rectangle((input - theta) / bandwidth) * grad_output
theta_grad = (
-(1.0 / bandwidth) * rectangle((input - theta) / bandwidth) * grad_output
)
return grad_input, theta_grad, None
Loading