Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ The `config` folder contains example configuration files for different architect
4. **Tensor parallelism works automatically** because PyTorch Lightning handles the distributed setup and PyTorch's Distributed Tensor API shards your model across GPUs without requiring changes to your component code


## Development

Format and lint the codebase:

```bash
uv run ruff format .
uv run ruff check --fix .
```

## Testing

Run the test suite to ensure everything is working correctly:
Expand Down
3 changes: 0 additions & 3 deletions crosslayer_transcoder/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
with efficient multiprocessing data generation.
"""

import torch
import torch.multiprocessing as mp

from .activation_sources import ActivationComputer, DiskActivationSource
from .data_generator import DataGeneratorProcess

Expand Down
5 changes: 2 additions & 3 deletions crosslayer_transcoder/data/activation_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def _extract_activations(self, model: Any, tokens: torch.Tensor, mask: torch.Ten
"""
mlp_ins = []
mlp_outs = []
with model.trace(tokens) as tracer:

with model.trace(tokens):
# Extract from all transformer layers
for i in range(self.n_layers):
# MLP input (after layer norm)
Expand Down Expand Up @@ -147,7 +146,7 @@ def _setup_file(self) -> None:
# self.tensor_handle = self.file_handle[self.accessor]
self.position = 0
except Exception as e:
raise RuntimeError(f"Failed to open activation file {self.file_path}: {e}")
raise RuntimeError(f"Failed to open activation file {self.file_path}: {e}") from e

def get_next_batch(self, batch_size: Optional[int] = None) -> torch.Tensor:
"""
Expand Down
5 changes: 1 addition & 4 deletions crosslayer_transcoder/data/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@
import os
from typing import Optional

import nnsight
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

from crosslayer_transcoder.data import text_dataset
from crosslayer_transcoder.data.activation_sources import ActivationComputer, DiskActivationSource
from crosslayer_transcoder.data.deployment_policy import DeploymentPolicy

Expand Down Expand Up @@ -205,7 +202,7 @@ def cleanup(self):
for child in children:
try:
child.terminate()
except:
except Exception:
pass
except ImportError:
pass # psutil not available
Expand Down
1 change: 0 additions & 1 deletion crosslayer_transcoder/data/deployment_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

import logging
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
Expand Down
5 changes: 2 additions & 3 deletions crosslayer_transcoder/data/generation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from crosslayer_transcoder.data import text_dataset
from crosslayer_transcoder.data.activation_sources import ActivationComputer, DiskActivationSource
from crosslayer_transcoder.data.deployment_policy import (
BaseDeploymentPolicy,
DeploymentPolicy,
create_deployment_policy,
)
Expand Down Expand Up @@ -138,7 +137,7 @@ def generation_loop(
# Generate new activations
gen_start = time.time()
activations = self._generate_activations()
gen_time = time.time() - gen_start
time.time() - gen_start

# Take only as many activations as we have indices
num_indices = len(indices_to_refresh)
Expand Down Expand Up @@ -256,7 +255,7 @@ def refill_from_disk(self, batch_size: int = 40_000):
self.shared_buffer.set_activations(indices_to_fill, samples)

# Calculate refill rate for this batch
batch_time = time.time() - batch_start_time
time.time() - batch_start_time
batch_refilled = len(indices_to_fill)
total_refilled += batch_refilled

Expand Down
2 changes: 1 addition & 1 deletion crosslayer_transcoder/data/process_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def update_dashboard(self, status: str, buffer_stats: Dict[str, Any], current_de

# Format uptime with more stable display
if uptime > 60:
uptime_str = f"{int(uptime/60)}:{int(uptime%60):02d}"
uptime_str = f"{int(uptime / 60)}:{int(uptime % 60):02d}"
else:
uptime_str = f"{int(uptime)}s"

Expand Down
5 changes: 1 addition & 4 deletions crosslayer_transcoder/data/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,11 @@
import atexit
import logging
import multiprocessing as mp
import threading
import time
from multiprocessing import shared_memory
from typing import Any, Dict, List, Tuple
from typing import Any, Dict

import numpy as np
import torch
import torch.multiprocessing as torch_mp

# No config imports needed in this module

Expand Down
6 changes: 2 additions & 4 deletions crosslayer_transcoder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ class CrossLayerTranscoderCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
"""Add custom argument linking and configuration."""
# Link model and data parameters that should be consistent
parser.link_arguments(
"data.init_args.n_layers", "model.init_args.nonlinearity.init_args.n_layers"
)
parser.link_arguments("data.init_args.n_layers", "model.init_args.nonlinearity.init_args.n_layers")
parser.link_arguments("data.init_args.n_layers", "model.init_args.n_layers")
parser.link_arguments(
"model.init_args.d_features",
Expand All @@ -34,7 +32,7 @@ def main():
os.environ.setdefault("WANDB_CACHE_DIR", f"{os.getcwd()}/wandb_cache")

# Create CLI with subclass mode to support class_path configuration
cli = CrossLayerTranscoderCLI(
CrossLayerTranscoderCLI(
model_class=L.LightningModule, # Use base class for subclass mode
datamodule_class=L.LightningDataModule, # Use base class for subclass mode
subclass_mode_model=True, # Enable subclass mode for model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def update(self, clt, max_batches=20):
with torch.no_grad():
for i, (tokens, mask) in enumerate(self.loader):
torch.cuda.empty_cache()
print(f"computing replacement model", i)
print("computing replacement model", i)
tokens = self.handle_device(tokens)
mask = self.handle_device(mask)
if i >= max_batches:
Expand Down
16 changes: 10 additions & 6 deletions crosslayer_transcoder/model/clt.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def decode(
"from_layer to_layer -> batch_size to_layer d_acts",
)

def forward(self, acts: Float[torch.Tensor, "batch_size n_layers d_acts"]) -> Tuple[
def forward(
self, acts: Float[torch.Tensor, "batch_size n_layers d_acts"]
) -> Tuple[
Float[torch.Tensor, "batch_size n_layers d_features"],
Float[torch.Tensor, "batch_size n_layers d_features"],
Float[torch.Tensor, "batch_size n_layers d_acts"],
Expand Down Expand Up @@ -158,12 +160,12 @@ def __init__(self, d_acts: int, d_features: int, n_layers: int):
self.d_acts = d_acts
self.d_features = d_features
self.n_layers = n_layers
self.register_parameter(f"W", nn.Parameter(torch.empty((n_layers, d_features, d_acts))))
self.register_parameter("W", nn.Parameter(torch.empty((n_layers, d_features, d_acts))))
self.reset_parameters()

def reset_parameters(self):
dec_uniform_thresh = 1 / ((self.d_acts * self.n_layers) ** 0.5)
self.get_parameter(f"W").data.uniform_(-dec_uniform_thresh, dec_uniform_thresh)
self.get_parameter("W").data.uniform_(-dec_uniform_thresh, dec_uniform_thresh)

@torch.no_grad()
def forward_layer(
Expand All @@ -173,7 +175,7 @@ def forward_layer(
features = features[:, :, layer, :]
return einsum(
features,
self.get_parameter(f"W")[layer],
self.get_parameter("W")[layer],
"batch_size seq d_features, d_features d_acts -> batch_size seq d_acts",
)

Expand All @@ -199,7 +201,7 @@ def __init__(self, d_acts: int, d_features: int, n_layers: int):
self.n_layers = n_layers
for i in range(n_layers):
self.register_parameter(f"W_{i}", nn.Parameter(torch.empty((i + 1, d_features, d_acts))))
self.register_parameter(f"b", nn.Parameter(torch.empty((n_layers, d_acts))))
self.register_parameter("b", nn.Parameter(torch.empty((n_layers, d_acts))))
self.reset_parameters()

def reset_parameters(self):
Expand Down Expand Up @@ -273,7 +275,9 @@ def initialize_standardizers(self, batch: Float[torch.Tensor, "batch_size io n_l
self.input_standardizer.initialize_from_batch(batch)
self.output_standardizer.initialize_from_batch(batch)

def forward(self, acts: Float[torch.Tensor, "batch_size n_layers d_acts"]) -> Tuple[
def forward(
self, acts: Float[torch.Tensor, "batch_size n_layers d_acts"]
) -> Tuple[
Float[torch.Tensor, "batch_size n_layers d_features"], # pre_actvs
Float[torch.Tensor, "batch_size n_layers d_features"], # features
Float[torch.Tensor, "batch_size n_layers d_acts"], # recons_norm
Expand Down
55 changes: 12 additions & 43 deletions crosslayer_transcoder/model/clt_lightning.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
import gc
import os
import subprocess
import time
from typing import Optional, Tuple

import lightning as L
import psutil
import torch
import torch.nn as nn
from einops import einsum
import wandb
from jaxtyping import Float
from torch.distributed.tensor.parallel import parallelize_module

import wandb
from crosslayer_transcoder.metrics.dead_features import DeadFeatures
from crosslayer_transcoder.metrics.replacement_model_accuracy import (
ReplacementModelAccuracy,
Expand All @@ -22,8 +16,6 @@
CrossLayerTranscoder,
Decoder,
)
from crosslayer_transcoder.model.jumprelu import JumpReLU
from crosslayer_transcoder.model.topk import BatchTopK


class CrossLayerTranscoderModule(L.LightningModule):
Expand Down Expand Up @@ -55,9 +47,7 @@ def __init__(
):
super().__init__(*args, **kwargs)

self.save_hyperparameters(
ignore=["model", "replacement_model", "dead_features"]
)
self.save_hyperparameters(ignore=["model", "replacement_model", "dead_features"])
# torch.cuda.memory._record_memory_history(max_entries=100_000)

# Store pre-constructed modules
Expand Down Expand Up @@ -143,9 +133,7 @@ def log_training_metrics(self, features, recons_norm, recons, mlp_out, batch_idx

ss_err = (mlp_out_norm - recons_norm) ** 2
ss_err = ss_err.sum(dim=0)
ss_total = ((mlp_out_norm - mlp_out_norm.mean(dim=0, keepdim=True)) ** 2).sum(
dim=0
)
ss_total = ((mlp_out_norm - mlp_out_norm.mean(dim=0, keepdim=True)) ** 2).sum(dim=0)
fvu = (ss_err / ss_total).mean() # (n_layers, d_model)
self.log("metrics/fraction_of_variance_unexplained", fvu)
fvu_per_layer = (ss_err / ss_total).mean(dim=-1)
Expand Down Expand Up @@ -189,8 +177,7 @@ def log_training_metrics(self, features, recons_norm, recons, mlp_out, batch_idx
self.log("metrics/recons_standardized_std", recons_norm.std())
self.log(
"metrics/L0_avg_per_layer",
torch.count_nonzero(active_features)
/ (features.shape[0] * features.shape[1]),
torch.count_nonzero(active_features) / (features.shape[0] * features.shape[1]),
)

# Magnitude of feature activations - memory efficient version
Expand All @@ -216,21 +203,15 @@ def log_training_metrics(self, features, recons_norm, recons, mlp_out, batch_idx

# Log L0 table per layer
if batch_idx % 500 == 1:
l0_per_layer = (
torch.count_nonzero(active_features, dim=(0, 2)) / features.shape[0]
)
l0_per_layer = torch.count_nonzero(active_features, dim=(0, 2)) / features.shape[0]

if self.logger and isinstance(self.logger.experiment, wandb.wandb_run.Run):
table = wandb.Table(
data=[[i, v.item()] for i, v in enumerate(l0_per_layer.cpu())],
columns=["layer", "L0"],
)
self.logger.experiment.log(
{
"layers/L0_per_layer": wandb.plot.bar(
table, "layer", "L0", title="L0 per Layer"
)
},
{"layers/L0_per_layer": wandb.plot.bar(table, "layer", "L0", title="L0 per Layer")},
step=self.global_step,
)

Expand Down Expand Up @@ -259,15 +240,11 @@ def log_training_metrics(self, features, recons_norm, recons, mlp_out, batch_idx
):
for layer in range(dead_log_freqs.shape[0]):
self.logger.experiment.log(
{
f"layers/log_feature_density/layer_{layer}": dead_log_freqs[
layer
]
},
{f"layers/log_feature_density/layer_{layer}": dead_log_freqs[layer]},
step=self.global_step,
)
self.logger.experiment.log(
{f"training/log_feature_density": dead_log_freqs.flatten()},
{"training/log_feature_density": dead_log_freqs.flatten()},
step=self.global_step,
)
self.log("training/log_feature_density_mean", dead_log_freqs.mean())
Expand Down Expand Up @@ -433,27 +410,19 @@ def training_step(self, batch, batch_idx):
if isinstance(self.model.decoder, CrosslayerDecoder):
dec_norms = torch.zeros_like(features[:1])
for l in range(self.model.decoder.n_layers):
W = self.model.decoder.get_parameter(
f"W_{l}"
) # (from_layer, d_features, d_acts)
W = self.model.decoder.get_parameter(f"W_{l}") # (from_layer, d_features, d_acts)
dec_norms[:, : l + 1] = dec_norms[:, : l + 1] + (W**2).sum(dim=-1)
dec_norms = dec_norms.sqrt()

elif isinstance(self.model.decoder, Decoder):
dec_norms = torch.sqrt((self.model.decoder.W**2).sum(dim=-1))

weighted_features = features * dec_norms * self.c
self.log(
"model/weighted_features_mean", weighted_features.detach().mean().cpu()
)
self.log("model/weighted_features_mean", weighted_features.detach().mean().cpu())

if self.use_tanh:
weighted_features = torch.tanh(
weighted_features
) # (batch_size, n_layers, d_features)
sparsity = (
self.current_sparsity_penalty() * weighted_features.sum(dim=[1, 2]).mean()
)
weighted_features = torch.tanh(weighted_features) # (batch_size, n_layers, d_features)
sparsity = self.current_sparsity_penalty() * weighted_features.sum(dim=[1, 2]).mean()
self.log("training/sparsity_loss", sparsity)

# Compute Pre-activation Loss
Expand Down
Loading