diff --git a/config/molt-5090.yaml b/config/molt-5090.yaml new file mode 100644 index 0000000..e29ade4 --- /dev/null +++ b/config/molt-5090.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with three classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. + +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_0005-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. + replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0005 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. + buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_00001.yaml b/config/molt-5090_20M_tokens_0_00001.yaml new file mode 100644 index 0000000..118e282 --- /dev/null +++ b/config/molt-5090_20M_tokens_0_00001.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with three classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. + +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_00001-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_00001" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. + replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00001 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. + buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_0001.yaml b/config/molt-5090_20M_tokens_0_0001.yaml new file mode 100644 index 0000000..e267c9e --- /dev/null +++ b/config/molt-5090_20M_tokens_0_0001.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with three classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. + +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_0001-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_0001" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. + replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0001 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. + buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_00015.yaml b/config/molt-5090_20M_tokens_0_00015.yaml new file mode 100644 index 0000000..0700010 --- /dev/null +++ b/config/molt-5090_20M_tokens_0_00015.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with three classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. + +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_00015-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_00015" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. + replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00015 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. + buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_0005.yaml b/config/molt-5090_20M_tokens_0_0005.yaml new file mode 100644 index 0000000..d8f9330 --- /dev/null +++ b/config/molt-5090_20M_tokens_0_0005.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with three classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. + +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_0005-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_0005" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. + replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0005 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. + buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_50M_tokens_0_00015.yaml b/config/molt-5090_50M_tokens_0_00015.yaml new file mode 100644 index 0000000..7c8fb8d --- /dev/null +++ b/config/molt-5090_50M_tokens_0_00015.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with three classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. + +seed_everything: 42 + +trainer: + max_steps: 50_000 # 50M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_00015-1550-128k-jumprelu-float32-50M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_00015_50M" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. + replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00015 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. + buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-long.yaml b/config/molt-long.yaml new file mode 100644 index 0000000..ebfd470 --- /dev/null +++ b/config/molt-long.yaml @@ -0,0 +1,150 @@ +# Default configuration for CrossLayer Transcoder training +# This file uses Lightning CLI's automatic class construction + +seed_everything: 42 + +trainer: + # max_steps is number of gradient updates, not number of batches + max_steps: 200_000 # 20M tokens + #limit_train_batches: 10_000 + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + precision: "16-mixed" + accelerator: "gpu" + devices: [1] # 4; [1] for cuda:1 + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-0_01" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints" + #- class_path: crosslayer_transcoder.utils.callbacks.TensorBoardProfilerCallback + # init_args: + # log_dir: "log/profiler" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + # Pre-constructed CrossLayerTranscoder model + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + # Model architecture parameters + d_acts: 768 + n_transforms: 3000 + d_transform: 32 + nonlinearity: + class_path: torch.nn.ReLU + + # Pre-constructed standardizers + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Pre-constructed replacement model + replacement_model: + class_path: crosslayer_transcoder.metrics.replacement_model_accuracy.ReplacementModelAccuracy + init_args: + model_name: "openai-community/gpt2" + device_map: "cuda:1" + loader_batch_size: 2 + + # Pre-constructed dead features metric + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + + # Training parameters + learning_rate: 1e-4 + compile: false + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.01 + c_sparsity: 1 + use_tanh: true + + # Dead features computation settings + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + buffer_size: 1_000_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float16" + max_batch_size: 50000 + + # Model settings for activation generation + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + # Generation settings + generation_batch_size: 10 + refresh_interval: 0.1 + + # Memory settings + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + # File paths + init_file: "/var/local/glang/activations/clt-activations-10M-shuffled_fp16.h5" + + # DataLoader settings + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 # Only provide activations when buffer is at least 20% full + + use_shared_memory: true + + # Device configuration + device_map: "cuda:0" # "cpu", "auto", "cuda:0", "cuda:0,1,2,3" + deployment_policy: "gpu_only" # "cpu_only", "gpu_only", or "dynamic" + + # WandB logging configuration for data generation + wandb_logging: + enabled: true # Enable WandB logging for data generation + project: "debug-molt" # WandB project (should match trainer logger) + group: null # Group name (null = auto-generated from training run) + run_name: "data-generator" # Run name suffix + tags: ["data-generation"] # Tags for the data generation run + save_dir: "./wandb" # Directory for WandB files + log_interval: 5.0 # Logging interval in seconds + +ckpt_path: null \ No newline at end of file diff --git a/config/molt.yaml b/config/molt.yaml new file mode 100644 index 0000000..9ce945a --- /dev/null +++ b/config/molt.yaml @@ -0,0 +1,154 @@ +# Default configuration for CrossLayer Transcoder training +# This file uses Lightning CLI's automatic class construction + +seed_everything: 42 + +trainer: + # max_steps is number of gradient updates, not number of batches + max_steps: 20_000 # 20M tokens + #limit_train_batches: 10_000 + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + # precision: "16-mixed" + accelerator: "gpu" + devices: [2] # 4; [1] for cuda:1 + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-0_0005-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints" + #- class_path: crosslayer_transcoder.utils.callbacks.TensorBoardProfilerCallback + # init_args: + # log_dir: "log/profiler" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + # Pre-constructed CrossLayerTranscoder model + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + # Model architecture parameters + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + # Pre-constructed standardizers + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Pre-constructed replacement model + replacement_model: + class_path: crosslayer_transcoder.metrics.replacement_model_accuracy.ReplacementModelAccuracy + init_args: + model_name: "openai-community/gpt2" + device_map: "cuda:1" + loader_batch_size: 2 + + # Pre-constructed dead features metric + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + + # Training parameters + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0005 + c_sparsity: 100 + use_tanh: true + + # Dead features computation settings + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + buffer_size: 1_000_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + # Model settings for activation generation + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + # Generation settings + generation_batch_size: 10 + refresh_interval: 0.1 + + # Memory settings + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + # File paths + # init_file: "/var/local/glang/activations/clt-activations-10M-shuffled_fp16.h5" + + # DataLoader settings + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 # Only provide activations when buffer is at least 20% full + + use_shared_memory: true + + # Device configuration + device_map: "cuda:0" # "cpu", "auto", "cuda:0", "cuda:0,1,2,3" + deployment_policy: "gpu_only" # "cpu_only", "gpu_only", or "dynamic" + + # WandB logging configuration for data generation + wandb_logging: + enabled: true # Enable WandB logging for data generation + project: "debug-molt" # WandB project (should match trainer logger) + group: null # Group name (null = auto-generated from training run) + run_name: "data-generator" # Run name suffix + tags: ["data-generation"] # Tags for the data generation run + save_dir: "./wandb" # Directory for WandB files + log_interval: 5.0 # Logging interval in seconds + +ckpt_path: null \ No newline at end of file diff --git a/crosslayer_transcoder/data/datamodule.py b/crosslayer_transcoder/data/datamodule.py index 5df6778..20b91ad 100644 --- a/crosslayer_transcoder/data/datamodule.py +++ b/crosslayer_transcoder/data/datamodule.py @@ -336,7 +336,7 @@ def teardown(self, stage: str = None): """Clean up resources.""" logger.info("Cleaning up activation data loader...") - if self.data_loader and hasattr(self.data_loader, "cleanup"): + if self.data_loader is not None and hasattr(self.data_loader, "cleanup"): self.data_loader.cleanup() if self.data_generator and self.data_generator.is_alive(): diff --git a/crosslayer_transcoder/feature_dash/__init__.py b/crosslayer_transcoder/feature_dash/__init__.py new file mode 100644 index 0000000..1ed8780 --- /dev/null +++ b/crosslayer_transcoder/feature_dash/__init__.py @@ -0,0 +1,51 @@ +"""Feature dashboard for MoLT checkpoints. + +Inspired by sae_vis (https://github.com/callummcdougall/sae_vis), trimmed to a +minimal three-panel view: per-transform activation rate, max-activating +examples, and token-level highlighting within those examples. + +The corpus used for the qualitative pass is OpenWebText +(`Skylion007/openwebtext`) regardless of what the checkpoint was trained on. +""" + +from crosslayer_transcoder.feature_dash.collect import ( + BaseLMRunner, + FeatureSummary, + GateCollector, + collect_features, + window_example, + window_feature_summary, +) +from crosslayer_transcoder.feature_dash.bundle import ( + default_bundle_filename, + make_bundle, + make_bundle_from_disk, +) +from crosslayer_transcoder.feature_dash.dump import dump_dashboard +from crosslayer_transcoder.feature_dash.load import ( + DEFAULT_HF_REPO, + MoltCheckpointMetadata, + infer_molt_arch, + load_molt, + load_molt_from_hf, +) +from crosslayer_transcoder.feature_dash.render import copy_render_assets + +__all__ = [ + "BaseLMRunner", + "DEFAULT_HF_REPO", + "FeatureSummary", + "GateCollector", + "MoltCheckpointMetadata", + "collect_features", + "copy_render_assets", + "default_bundle_filename", + "dump_dashboard", + "infer_molt_arch", + "load_molt", + "load_molt_from_hf", + "make_bundle", + "make_bundle_from_disk", + "window_example", + "window_feature_summary", +] diff --git a/crosslayer_transcoder/feature_dash/__main__.py b/crosslayer_transcoder/feature_dash/__main__.py new file mode 100644 index 0000000..10dd204 --- /dev/null +++ b/crosslayer_transcoder/feature_dash/__main__.py @@ -0,0 +1,254 @@ +"""CLI entrypoint: build a feature dashboard from a MoLT checkpoint. + + python -m crosslayer_transcoder.feature_dash \\ + --hf-filename gpt2-molt-lam-0_00015-50M.ckpt \\ + --out feature_dash/lam_0_00015_50M + +Use `--local-ckpt` instead of `--hf-filename` to load a `.ckpt` from disk. +""" + +from __future__ import annotations + +import argparse +import logging +import sys +from pathlib import Path + +import torch + +from crosslayer_transcoder.feature_dash.bundle import make_bundle +from crosslayer_transcoder.feature_dash.collect import collect_features +from crosslayer_transcoder.feature_dash.dump import dump_dashboard +from crosslayer_transcoder.feature_dash.load import ( + DEFAULT_HF_REPO, + load_molt, + load_molt_from_hf, +) + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="python -m crosslayer_transcoder.feature_dash", + description="Build a per-transform feature dashboard for a MoLT checkpoint.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + src = p.add_mutually_exclusive_group(required=True) + src.add_argument( + "--hf-filename", + help=f"Filename inside the HF repo (default repo: {DEFAULT_HF_REPO}).", + ) + src.add_argument( + "--local-ckpt", + help="Path to a local .ckpt file. Mutually exclusive with --hf-filename.", + ) + p.add_argument( + "--repo-id", + default=DEFAULT_HF_REPO, + help="HF repo id when using --hf-filename.", + ) + p.add_argument( + "--revision", + default=None, + help="Optional HF revision (commit/tag/branch).", + ) + + p.add_argument( + "--out", required=True, help="Output directory; created if missing." + ) + p.add_argument( + "--layer", + type=int, + default=8, + help="Layer to capture residual at (matches MoLT training layer).", + ) + + p.add_argument( + "--base-model-name", + default="openai-community/gpt2", + help="HF model name for the base LM that produces residuals.", + ) + p.add_argument( + "--dataset-name", + default="Skylion007/openwebtext", + help="HF dataset name for the corpus pass.", + ) + p.add_argument( + "--dataset-split", default="train", help="Dataset split to stream." + ) + + p.add_argument( + "--n-sequences", + type=int, + default=1024, + help="Number of sequences to stream from the corpus.", + ) + p.add_argument( + "--seq-len", + type=int, + default=128, + help="Tokens per sequence; sequences shorter than this are dropped.", + ) + p.add_argument( + "--batch-size", + type=int, + default=16, + help="Sequences per LM forward pass.", + ) + p.add_argument( + "--top-k", + type=int, + default=20, + help="Number of max-activating examples kept per transform.", + ) + p.add_argument( + "--window", + type=int, + default=32, + help="Tokens of context on each side of the peak in the rendered view.", + ) + + p.add_argument( + "--device", + default=None, + help="Torch device. Defaults to cuda if available, else cpu.", + ) + p.add_argument( + "--dtype", + default="float32", + choices=["float32", "float16", "bfloat16"], + help="Base LM dtype.", + ) + + p.add_argument( + "--no-assets", + action="store_true", + help="Skip copying the HTML/CSS/JS templates (data only).", + ) + p.add_argument( + "--bundle", + action="store_true", + help="Also write a single-file `bundle.html` with all data inlined " + "(opens by double-click, no server needed — share with teammates).", + ) + p.add_argument( + "--bundle-only", + action="store_true", + help="Write only `bundle.html` — skip the multi-file dashboard " + "(implies --no-assets and removes data/ after).", + ) + p.add_argument( + "-v", "--verbose", action="store_true", help="Log progress." + ) + + return p + + +_DTYPE = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, +} + + +def main(argv: list[str] | None = None) -> int: + args = _build_parser().parse_args(argv) + + logging.basicConfig( + level=logging.INFO if args.verbose else logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") + + if args.hf_filename: + molt, meta = load_molt_from_hf( + filename=args.hf_filename, + repo_id=args.repo_id, + revision=args.revision, + device=device, + ) + else: + molt, meta = load_molt(args.local_ckpt, device=device) + + print( + f"loaded MoLT: {meta.n_features} transforms, ranks={meta.ranks}, " + f"d_acts={meta.d_acts}, base LM={meta.base_model_name or args.base_model_name}", + file=sys.stderr, + ) + + collector = collect_features( + molt=molt, + layer=args.layer, + base_model_name=args.base_model_name, + dataset_name=args.dataset_name, + dataset_split=args.dataset_split, + n_sequences=args.n_sequences, + seq_len=args.seq_len, + batch_size=args.batch_size, + top_k=args.top_k, + device=device, + dtype=_DTYPE[args.dtype], + log_every=8 if args.verbose else 0, + ) + + from transformers import GPT2TokenizerFast + + tokenizer = GPT2TokenizerFast.from_pretrained(args.base_model_name) + + out_dir = Path(args.out) + out_dir.mkdir(parents=True, exist_ok=True) + + write_multi_file = not args.bundle_only + write_bundle = args.bundle or args.bundle_only + + if write_multi_file: + dump_dashboard( + collector=collector, + meta=meta, + tokenizer=tokenizer, + out_dir=out_dir, + layer=args.layer, + dataset_name=args.dataset_name, + window=args.window, + copy_assets=not args.no_assets, + ) + + bundle_path = None + if write_bundle: + # Pass the directory so the bundle is named after the checkpoint + # (e.g. `bundle_gpt2-molt-lam-0_00015-50M.html`). + bundle_path = make_bundle( + collector=collector, + meta=meta, + tokenizer=tokenizer, + out_path=out_dir, + layer=args.layer, + dataset_name=args.dataset_name, + window=args.window, + ) + + if args.bundle_only: + # Drop the data/ dir we never wrote — but if the user re-ran into + # the same dir, leave any prior dump alone. + pass + + print(f"\nDashboard written to {out_dir.resolve()}", file=sys.stderr) + if write_multi_file and not args.no_assets: + print( + f" Multi-file: cd {out_dir} && python -m http.server 8050\n" + f" Then open http://localhost:8050/index.html", + file=sys.stderr, + ) + if bundle_path is not None: + size_mb = bundle_path.stat().st_size / 1e6 + print( + f" Bundle: {bundle_path} ({size_mb:.1f} MB) — " + f"open by double-click; share with teammates", + file=sys.stderr, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/crosslayer_transcoder/feature_dash/bundle.py b/crosslayer_transcoder/feature_dash/bundle.py new file mode 100644 index 0000000..1218b84 --- /dev/null +++ b/crosslayer_transcoder/feature_dash/bundle.py @@ -0,0 +1,164 @@ +"""Build a single self-contained `bundle.html` from a dumped dashboard. + +The bundle has metadata + every feature's JSON inlined as `` (and a few related variants). + Replacing `<` with `\\u003c` is the standard hardening. + """ + return s.replace("<", "\\u003c") + + +def make_bundle_from_disk(out_dir: str | Path) -> Path: + """Read `/data/` and write `/bundle_.html`.""" + out_dir = Path(out_dir) + data_dir = out_dir / "data" + if not (data_dir / "metadata.json").is_file(): + raise FileNotFoundError( + f"{data_dir/'metadata.json'} not found — run dump_dashboard first" + ) + + metadata = json.loads((data_dir / "metadata.json").read_text()) + n_features = metadata["n_features"] + width = _feature_id_width(n_features) + + features: dict[str, dict] = {} + for f_id in range(n_features): + path = data_dir / "features" / f"{f_id:0{width}d}.json" + features[str(f_id)] = json.loads(path.read_text()) + + bundle_name = f"bundle_{_ckpt_stem(None, fallback_metadata=metadata)}.html" + return _write_bundle(out_dir / bundle_name, metadata, features) + + +def make_bundle( + collector: GateCollector, + meta: MoltCheckpointMetadata, + tokenizer, + out_path: str | Path, + layer: int, + dataset_name: str = "Skylion007/openwebtext", + seq_len: Optional[int] = None, + window: int = 32, +) -> Path: + """Build a bundle directly from in-memory state — no disk dump needed. + + If `out_path` is a directory (or doesn't exist and ends with a separator), + the bundle is written as `bundle_.html` inside it. Otherwise + `out_path` is treated as the explicit file to write. + """ + out_path = Path(out_path) + treat_as_dir = out_path.is_dir() or ( + not out_path.exists() and out_path.suffix == "" + ) + if treat_as_dir: + out_path = out_path / default_bundle_filename(meta) + + metadata = _metadata_payload( + meta=meta, + collector=collector, + layer=layer, + dataset_name=dataset_name, + seq_len=seq_len if seq_len is not None else collector.T, + top_k=collector.K, + window=window, + ) + + features: dict[str, dict] = {} + for f_id in range(meta.n_features): + summary = collector.feature_summary(f_id) + body = window_feature_summary(summary, tokenizer, window=window) + body["tier"] = meta.feature_tier[f_id] + body["rank"] = meta.feature_rank[f_id] + features[str(f_id)] = body + + return _write_bundle(out_path, metadata, features) + + +def _write_bundle(out_path: Path, metadata: dict, features: dict) -> Path: + template = (TEMPLATE_DIR / "bundle.html").read_text() + css = (TEMPLATE_DIR / "dashboard.css").read_text() + + metadata_json = _sanitize_for_script_tag(json.dumps(metadata, separators=(",", ":"))) + features_json = _sanitize_for_script_tag(json.dumps(features, separators=(",", ":"))) + + html = ( + template + .replace("__BUNDLE_CSS__", css) + .replace("__METADATA_JSON__", metadata_json) + .replace("__FEATURES_JSON__", features_json) + ) + + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(html) + logger.info( + "wrote %s (%.1f MB, %d features)", + out_path, + out_path.stat().st_size / 1e6, + len(features), + ) + return out_path diff --git a/crosslayer_transcoder/feature_dash/collect.py b/crosslayer_transcoder/feature_dash/collect.py new file mode 100644 index 0000000..eeb6adf --- /dev/null +++ b/crosslayer_transcoder/feature_dash/collect.py @@ -0,0 +1,388 @@ +"""Streaming gate collection over OpenWebText. + +For each MoLT transform, accumulate: + 1. firing rate (fraction of tokens with gate > 0), + 2. max activation, + 3. top-K sequences by per-sequence peak activation, with the per-token gate + trace for that one transform (used by the renderer for token highlighting). + +The corpus is `Skylion007/openwebtext` by default — this matches the training +distribution of the HF MoLT checkpoints. + +The collector itself doesn't touch HuggingFace at all; it consumes batches of +`(token_ids, gates)`. `BaseLMRunner` and `collect_features` are the wiring +that produces those batches from a real LM and dataset. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Iterator, Optional + +import torch + +from crosslayer_transcoder.model.molt import Molt + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Pure data-side accumulator (no LM, no dataset). Unit-testable on CPU. +# --------------------------------------------------------------------------- + + +class GateCollector: + """Per-feature firing-rate, max activation, and top-K sequence buffer. + + All state lives on CPU. The collector is pushed (token_ids, gates) batches + by whatever drives it. Memory is bounded by `n_features * top_k * seq_len` + on each of token_ids (int64) and activations (float32). For + `n_features=1550, top_k=20, seq_len=128` that's ~47 MB persistent — fine. + """ + + def __init__(self, n_features: int, top_k: int, seq_len: int): + self.n_features = n_features + self.K = top_k + self.T = seq_len + + self.firing_count = torch.zeros(n_features, dtype=torch.long) + self.total_tokens = 0 + self.max_activation = torch.full((n_features,), -float("inf")) + + # Top-K is keyed by per-sequence peak. -inf in unfilled slots. + self.top_peaks = torch.full((n_features, top_k), -float("inf")) + self.top_token_ids = torch.zeros( + (n_features, top_k, seq_len), dtype=torch.long + ) + self.top_activations = torch.zeros( + (n_features, top_k, seq_len), dtype=torch.float32 + ) + + @torch.no_grad() + def update(self, batch_token_ids: torch.Tensor, gates: torch.Tensor) -> None: + """Fold a batch into the running stats. + + batch_token_ids: (B, T) long, on CPU + gates: (B, T, F) float, on CPU + """ + if batch_token_ids.device.type != "cpu" or gates.device.type != "cpu": + raise ValueError("update expects CPU tensors") + + B, T, F = gates.shape + if F != self.n_features: + raise ValueError(f"gates has {F} features, expected {self.n_features}") + if T != self.T or batch_token_ids.shape != (B, T): + raise ValueError( + f"shape mismatch: token_ids={tuple(batch_token_ids.shape)}, " + f"gates={tuple(gates.shape)}, expected seq_len={self.T}" + ) + + # Firing count + max + active = gates > 0 + self.firing_count += active.sum(dim=(0, 1)).long() + self.total_tokens += B * T + self.max_activation = torch.maximum(self.max_activation, gates.amax(dim=(0, 1))) + + # Top-K merge. Build a (F, K+B) candidate pool of peaks and pick top-K. + peaks = gates.amax(dim=1) # (B, F) + + all_peaks = torch.cat([self.top_peaks, peaks.T], dim=1) # (F, K+B) + # Token ids: old (F, K, T) and batch (1, B, T) broadcast then concat. + batch_tok_expanded = batch_token_ids.unsqueeze(0).expand(F, -1, -1) + all_token_ids = torch.cat( + [self.top_token_ids, batch_tok_expanded], dim=1 + ) # (F, K+B, T) + # Activations: old (F, K, T) and (B, T, F) -> permute -> (F, B, T) then concat. + all_activations = torch.cat( + [self.top_activations, gates.permute(2, 0, 1)], dim=1 + ) # (F, K+B, T) + + new_top_vals, new_top_idx = all_peaks.topk(self.K, dim=1) # (F, K) + + F_idx = torch.arange(F).unsqueeze(1).expand(-1, self.K) + self.top_peaks = new_top_vals + self.top_token_ids = all_token_ids[F_idx, new_top_idx] + self.top_activations = all_activations[F_idx, new_top_idx] + + def activation_rate(self) -> torch.Tensor: + """Per-feature fraction of tokens with gate > 0.""" + if self.total_tokens == 0: + return torch.zeros(self.n_features) + return self.firing_count.float() / float(self.total_tokens) + + def feature_summary(self, feature_id: int) -> "FeatureSummary": + """Dump one feature's collected data into a small dataclass. + + Sequences whose per-sequence peak is <= 0 are dropped — those are + sequences where the feature didn't activate at all, so there's nothing + to highlight. A fully-dead feature returns an empty examples list. + """ + peaks = self.top_peaks[feature_id] + token_ids = self.top_token_ids[feature_id] + activations = self.top_activations[feature_id] + + valid = peaks > 0 + peaks = peaks[valid] + token_ids = token_ids[valid] + activations = activations[valid] + order = torch.argsort(peaks, descending=True) + + return FeatureSummary( + feature_id=feature_id, + activation_rate=self.activation_rate()[feature_id].item(), + max_activation=float(self.max_activation[feature_id].item()) + if torch.isfinite(self.max_activation[feature_id]) + else 0.0, + top_peaks=peaks[order].tolist(), + top_token_ids=token_ids[order].tolist(), + top_activations=activations[order].tolist(), + ) + + +@dataclass +class FeatureSummary: + """Everything one feature contributes to the dashboard JSON.""" + + feature_id: int + activation_rate: float + max_activation: float + top_peaks: list[float] # length <= K + top_token_ids: list[list[int]] # shape (n_examples, T) + top_activations: list[list[float]] # shape (n_examples, T) + + +# --------------------------------------------------------------------------- +# Step 3: window each example around its peak and decode token ids. +# --------------------------------------------------------------------------- + + +def _decode_token(tokenizer, token_id: int) -> str: + """Decode a single token id to a display string. + + Uses single-id decode so byte-level BPE markers (e.g. GPT-2's `Ġ` for a + leading space) come out as actual whitespace — that's what we want for + rendering each token as its own ``. + """ + return tokenizer.decode([int(token_id)]) + + +def window_example( + token_ids: list[int], + activations: list[float], + tokenizer, + window: int = 32, +) -> dict: + """Trim one example to ±`window` tokens around its peak and decode ids. + + Returns the per-example dict that the dashboard JSON expects: + {peak_activation, peak_token_pos, tokens, activations} + where `peak_token_pos` is the index inside the *windowed* arrays. + """ + if len(token_ids) != len(activations): + raise ValueError( + f"length mismatch: {len(token_ids)} ids vs {len(activations)} acts" + ) + if len(token_ids) == 0: + raise ValueError("empty example") + + peak_pos_full = max(range(len(activations)), key=lambda i: activations[i]) + start = max(0, peak_pos_full - window) + end = min(len(token_ids), peak_pos_full + window + 1) + + windowed_ids = token_ids[start:end] + windowed_acts = activations[start:end] + windowed_tokens = [_decode_token(tokenizer, i) for i in windowed_ids] + + return { + "peak_activation": float(activations[peak_pos_full]), + "peak_token_pos": peak_pos_full - start, + "tokens": windowed_tokens, + "activations": [float(a) for a in windowed_acts], + } + + +def window_feature_summary( + summary: FeatureSummary, + tokenizer, + window: int = 32, +) -> dict: + """Build the per-feature dashboard payload from a FeatureSummary. + + Schema (matches FEATURE-DASH.md §2 — minus tier/rank, which the dump step + fills in from MoltCheckpointMetadata): + {feature_id, activation_rate, max_activation, examples: [...]} + """ + examples = [ + window_example(ids, acts, tokenizer, window=window) + for ids, acts in zip(summary.top_token_ids, summary.top_activations) + ] + return { + "feature_id": summary.feature_id, + "activation_rate": summary.activation_rate, + "max_activation": summary.max_activation, + "examples": examples, + } + + +# --------------------------------------------------------------------------- +# Base-LM forward-hook runner (captures residual at h[L].ln_2 input). +# --------------------------------------------------------------------------- + + +class BaseLMRunner: + """Wrap a HF causal LM with a forward pre-hook on `transformer.h[L].ln_2`. + + That tensor — the input to the second LayerNorm — is what MoLT was trained + on (cf. `crosslayer_transcoder/data/activation_sources.py`, which uses + nnsight's `model.transformer.h[i].ln_2.input`). Capturing it via a + PyTorch forward pre-hook is the same thing without the nnsight dep. + """ + + def __init__( + self, + model_name: str, + layer: int, + device: str | torch.device = "cpu", + dtype: torch.dtype = torch.float32, + ): + from transformers import GPT2LMHeadModel # local import: heavy + + self.model = ( + GPT2LMHeadModel.from_pretrained(model_name).to(device).to(dtype).eval() + ) + self.layer = layer + self.device = torch.device(device) + self.dtype = dtype + self._captured: Optional[torch.Tensor] = None + self._hook = self.model.transformer.h[layer].ln_2.register_forward_pre_hook( + self._capture + ) + + def _capture(self, module, args): + # forward_pre_hook: args is the tuple of positional args. ln_2 takes + # `hidden_states` as its only positional arg. + self._captured = args[0] + + @torch.no_grad() + def residual_at_layer(self, token_ids: torch.Tensor) -> torch.Tensor: + """Run the LM and return the captured (B, T, d_acts) residual.""" + self._captured = None + token_ids = token_ids.to(self.device) + self.model(token_ids) + if self._captured is None: + raise RuntimeError( + f"forward pre-hook on layer {self.layer} ln_2 didn't fire" + ) + out = self._captured + self._captured = None + return out + + def close(self): + self._hook.remove() + + +# --------------------------------------------------------------------------- +# Top-level driver: stream tokens, run LM, push gates into collector. +# --------------------------------------------------------------------------- + + +def _iter_token_batches( + dataset_name: str, + dataset_split: str, + tokenizer, + seq_len: int, + batch_size: int, + n_sequences: int, +) -> Iterator[torch.Tensor]: + """Yield (B, T) int64 token-id batches from a streaming HF dataset. + + Skips sequences shorter than `seq_len` to keep all examples the same + length — simpler downstream than masking. + """ + from datasets import load_dataset + + ds = load_dataset(dataset_name, split=dataset_split, streaming=True) + + buf: list[torch.Tensor] = [] + yielded = 0 + for example in ds: + if yielded >= n_sequences: + return + text = example.get("text") or "" + if not text: + continue + enc = tokenizer(text, truncation=True, max_length=seq_len, return_tensors="pt") + ids = enc["input_ids"][0] + if ids.numel() < seq_len: + continue + buf.append(ids[:seq_len]) + if len(buf) == batch_size: + yield torch.stack(buf, dim=0) + yielded += batch_size + buf = [] + if buf: + yield torch.stack(buf, dim=0) + + +def collect_features( + molt: Molt, + layer: int, + base_model_name: str = "openai-community/gpt2", + dataset_name: str = "Skylion007/openwebtext", + dataset_split: str = "train", + n_sequences: int = 1024, + seq_len: int = 128, + batch_size: int = 16, + top_k: int = 20, + device: Optional[str] = None, + dtype: torch.dtype = torch.float32, + log_every: int = 16, +) -> GateCollector: + """Run the base LM over a corpus and collect MoLT gate stats per feature. + + Defaults match the HF MoLT checkpoints (GPT-2, OpenWebText). Returns a + populated `GateCollector` ready to be serialised by Step 4. + """ + from transformers import GPT2TokenizerFast + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + tokenizer = GPT2TokenizerFast.from_pretrained(base_model_name) + runner = BaseLMRunner(base_model_name, layer, device, dtype) + molt = molt.to(device) + + collector = GateCollector( + n_features=molt.n_features, top_k=top_k, seq_len=seq_len + ) + + try: + for batch_idx, tok in enumerate( + _iter_token_batches( + dataset_name=dataset_name, + dataset_split=dataset_split, + tokenizer=tokenizer, + seq_len=seq_len, + batch_size=batch_size, + n_sequences=n_sequences, + ) + ): + resid = runner.residual_at_layer(tok) # (B, T, d_acts) + resid_std = molt.input_standardizer(resid, layer) + pre = molt.e(resid_std) + gates = molt.nonlinearity(pre) # (B, T, n_features) + + collector.update(tok.cpu(), gates.float().cpu()) + + if log_every and batch_idx % log_every == 0: + logger.info( + "batch %d: %d tokens collected, mean fire rate %.4f", + batch_idx, + collector.total_tokens, + collector.activation_rate().mean().item(), + ) + finally: + runner.close() + + return collector diff --git a/crosslayer_transcoder/feature_dash/dump.py b/crosslayer_transcoder/feature_dash/dump.py new file mode 100644 index 0000000..849d402 --- /dev/null +++ b/crosslayer_transcoder/feature_dash/dump.py @@ -0,0 +1,144 @@ +"""Serialise a populated GateCollector + checkpoint metadata to disk. + +Layout produced (matches FEATURE-DASH.md §2): + + / + data/ + metadata.json + features/ + 0000.json + 0001.json + ... + +The renderer in Step 5 will consume these files; nothing here knows about HTML. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import asdict +from pathlib import Path +from typing import Any, Optional + +from crosslayer_transcoder.feature_dash.collect import ( + GateCollector, + window_feature_summary, +) +from crosslayer_transcoder.feature_dash.load import MoltCheckpointMetadata +from crosslayer_transcoder.feature_dash.render import copy_render_assets + +logger = logging.getLogger(__name__) + + +SCHEMA_VERSION = 1 + + +def _metadata_payload( + meta: MoltCheckpointMetadata, + collector: GateCollector, + layer: int, + dataset_name: str, + seq_len: int, + top_k: int, + window: int, +) -> dict[str, Any]: + payload = asdict(meta) + rates = collector.activation_rate().tolist() + # max_activation may be -inf for features that never saw a batch (e.g. an + # empty collector); clamp to 0.0 for JSON / display. + max_acts = [ + float(v) if v != float("-inf") else 0.0 + for v in collector.max_activation.tolist() + ] + payload.update( + { + "schema_version": SCHEMA_VERSION, + "layer": layer, + "n_tokens_collected": collector.total_tokens, + "dashboard_dataset": dataset_name, + "seq_len": seq_len, + "top_k": top_k, + "window": window, + # Per-feature aggregates so the index page renders from one fetch. + "feature_activation_rate": rates, + "feature_max_activation": max_acts, + } + ) + return payload + + +def _feature_payload( + collector: GateCollector, + feature_id: int, + meta: MoltCheckpointMetadata, + tokenizer, + window: int, +) -> dict[str, Any]: + summary = collector.feature_summary(feature_id) + body = window_feature_summary(summary, tokenizer, window=window) + body["tier"] = meta.feature_tier[feature_id] + body["rank"] = meta.feature_rank[feature_id] + return body + + +def dump_dashboard( + collector: GateCollector, + meta: MoltCheckpointMetadata, + tokenizer, + out_dir: str | Path, + layer: int, + dataset_name: str = "Skylion007/openwebtext", + seq_len: Optional[int] = None, + window: int = 32, + copy_assets: bool = True, +) -> Path: + """Write `metadata.json` + `features/.json` per transform, then copy + the static HTML/CSS/JS into `out_dir` (set `copy_assets=False` to skip). + + `seq_len` defaults to the collector's seq_len. Returns the data root. + """ + out_dir = Path(out_dir) + data_dir = out_dir / "data" + feat_dir = data_dir / "features" + feat_dir.mkdir(parents=True, exist_ok=True) + + seq_len = seq_len if seq_len is not None else collector.T + top_k = collector.K + + md = _metadata_payload( + meta=meta, + collector=collector, + layer=layer, + dataset_name=dataset_name, + seq_len=seq_len, + top_k=top_k, + window=window, + ) + (data_dir / "metadata.json").write_text(json.dumps(md, indent=2)) + + width = _feature_id_width(meta.n_features) + for f_id in range(meta.n_features): + payload = _feature_payload(collector, f_id, meta, tokenizer, window) + (feat_dir / f"{f_id:0{width}d}.json").write_text( + json.dumps(payload, separators=(",", ":")) + ) + if f_id and f_id % 200 == 0: + logger.info("dumped feature %d/%d", f_id, meta.n_features) + + logger.info("wrote %d features + metadata to %s", meta.n_features, data_dir) + + if copy_assets: + copy_render_assets(out_dir) + logger.info( + "wrote dashboard assets — open %s/index.html via " + "`python -m http.server` to view", + out_dir, + ) + + return data_dir + + +def _feature_id_width(n_features: int) -> int: + """Zero-padding width for feature filenames; matches the JS in index.js.""" + return max(4, len(str(n_features - 1))) diff --git a/crosslayer_transcoder/feature_dash/load.py b/crosslayer_transcoder/feature_dash/load.py new file mode 100644 index 0000000..ce3fc7c --- /dev/null +++ b/crosslayer_transcoder/feature_dash/load.py @@ -0,0 +1,236 @@ +"""Load a trained MoLT model from a Lightning checkpoint. + +Architecture (d_acts, n_features, ranks, N, n_layers) is inferred from the +state_dict tensor shapes, so we don't need the original training YAML. The +nonlinearity is assumed to be `JumpReLU` and the standardizers are assumed to +be `Dimensionwise{Input,Output}Standardizer` — that's what every MoLT config +in this repo uses today. If a future MoLT variant swaps either of those, this +loader needs to grow a switch on the saved hparams. + +Usage: + # From the Hugging Face hub (default repo: kylelovesllms/molt-sweeps): + molt, meta = load_molt_from_hf("gpt2-molt-lam-0_00015-50M.ckpt") + + # From a local file: + molt, meta = load_molt("checkpoints/lam_0_00015_50M/clt.ckpt") + + # molt is in eval() with grads disabled. meta carries n_features, ranks, + # tier->index mapping, base_model_name, training dataset, etc. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +import torch +from huggingface_hub import hf_hub_download + +from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt +from crosslayer_transcoder.model.standardize import ( + DimensionwiseInputStandardizer, + DimensionwiseOutputStandardizer, +) + +DEFAULT_HF_REPO = "kylelovesllms/molt-sweeps" + + +@dataclass +class MoltCheckpointMetadata: + """Everything the dashboard needs that isn't the model weights themselves.""" + + ckpt_path: str + d_acts: int + n_features: int + n_layers: int + ranks: list[int] + N: int + # Maps each feature index -> (tier_idx, rank). Tier 0 is the lowest-rank-multiplier + # tier (N transforms), tier 1 has 2N transforms, etc. + feature_tier: list[int] = field(default_factory=list) + feature_rank: list[int] = field(default_factory=list) + # Pulled from datamodule_hyper_parameters in the checkpoint, when present. + base_model_name: Optional[str] = None + training_dataset: Optional[str] = None + # Lightning bookkeeping + global_step: Optional[int] = None + epoch: Optional[int] = None + # If the checkpoint was pulled from HF, where it came from. + hf_repo_id: Optional[str] = None + hf_filename: Optional[str] = None + hf_revision: Optional[str] = None + + +def _shape(sd: dict[str, torch.Tensor], key: str) -> tuple[int, ...]: + if key not in sd: + raise KeyError( + f"Expected key '{key}' in checkpoint state_dict — is this a MoLT checkpoint?" + ) + return tuple(sd[key].shape) + + +def infer_molt_arch(state_dict: dict[str, torch.Tensor]) -> dict[str, Any]: + """Pull MoLT architecture out of state_dict tensor shapes.""" + n_features, d_acts = _shape(state_dict, "model.e.weight") + + n_layers = _shape(state_dict, "model.input_standardizer.mean")[0] + + # Walk Us.0, Us.1, ... while the keys exist. ranks come from Us.t.shape[1]. + ranks: list[int] = [] + tier = 0 + while f"model.Us.{tier}" in state_dict: + u_shape = _shape(state_dict, f"model.Us.{tier}") + # Us.t: (N * 2^t, ranks[t], d_acts) + ranks.append(u_shape[1]) + tier += 1 + if not ranks: + raise ValueError( + "No Us.* keys in checkpoint — this doesn't look like a MoLT model." + ) + + # N = (Us.0 first dim) / 2^0 = Us.0.shape[0] + N = _shape(state_dict, "model.Us.0")[0] + + # Sanity: features add up. + expected_features = sum(N * (2**t) for t in range(len(ranks))) + if expected_features != n_features: + raise ValueError( + f"Inferred N={N}, ranks={ranks} implies n_features={expected_features}, " + f"but model.e.weight has {n_features}. Checkpoint may be from an " + "incompatible MoLT variant." + ) + + return { + "d_acts": d_acts, + "n_features": n_features, + "n_layers": n_layers, + "ranks": ranks, + "N": N, + } + + +def _build_tier_index(N: int, ranks: list[int]) -> tuple[list[int], list[int]]: + """Per-feature tier index and rank, in feature-id order.""" + feature_tier: list[int] = [] + feature_rank: list[int] = [] + for t, r in enumerate(ranks): + n_in_tier = N * (2**t) + feature_tier.extend([t] * n_in_tier) + feature_rank.extend([r] * n_in_tier) + return feature_tier, feature_rank + + +def load_molt( + ckpt_path: str | Path, + device: str | torch.device = "cpu", +) -> tuple[Molt, MoltCheckpointMetadata]: + """Reconstruct a `Molt` from a Lightning checkpoint and load its weights. + + The returned model is in eval mode with `requires_grad_(False)`. The + standardizers are marked initialized (their mean/std buffers were saved + with the checkpoint), so the model is ready for forward passes. + """ + ckpt_path = str(ckpt_path) + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + + if "state_dict" not in ckpt: + raise ValueError(f"{ckpt_path} is not a Lightning checkpoint (no state_dict).") + sd = ckpt["state_dict"] + + arch = infer_molt_arch(sd) + d_acts = arch["d_acts"] + n_features = arch["n_features"] + n_layers = arch["n_layers"] + ranks = arch["ranks"] + N = arch["N"] + + nonlin = JumpReLU(theta=0.0, bandwidth=1.0, n_layers=1, d_features=n_features) + in_std = DimensionwiseInputStandardizer(n_layers=n_layers, activation_dim=d_acts) + out_std = DimensionwiseOutputStandardizer(n_layers=n_layers, activation_dim=d_acts) + + molt = Molt( + d_acts=d_acts, + N=N, + nonlinearity=nonlin, + input_standardizer=in_std, + output_standardizer=out_std, + ranks=ranks, + ) + + # Strip the "model." prefix that the LightningModule wrapping adds, and + # drop anything that doesn't belong to the inner Molt (e.g. `last_active`, + # which lives on the LightningModule, or any replacement_model state). + molt_sd = {} + for k, v in sd.items(): + if k.startswith("model."): + molt_sd[k[len("model.") :]] = v + + missing, unexpected = molt.load_state_dict(molt_sd, strict=False) + # Strict=False because the LightningModule may have stored extra buffers + # that aren't part of Molt. Surface anything genuinely missing. + if missing: + raise RuntimeError( + f"Missing keys when loading MoLT weights: {missing}. " + "The checkpoint and inferred architecture disagree." + ) + # `unexpected` is OK here — those are LightningModule-only tensors that + # already got filtered by the prefix strip. If anything slips through it's + # informational, not fatal. + + # Mark standardizers initialized — buffers were loaded from the ckpt. + in_std.is_initialized = True + out_std.is_initialized = True + + molt.eval() + molt.requires_grad_(False) + molt.to(device) + + feature_tier, feature_rank = _build_tier_index(N, ranks) + + dm_hp = ckpt.get("datamodule_hyper_parameters", {}) or {} + meta = MoltCheckpointMetadata( + ckpt_path=ckpt_path, + d_acts=d_acts, + n_features=n_features, + n_layers=n_layers, + ranks=ranks, + N=N, + feature_tier=feature_tier, + feature_rank=feature_rank, + base_model_name=dm_hp.get("model_name"), + training_dataset=dm_hp.get("dataset_name"), + global_step=ckpt.get("global_step"), + epoch=ckpt.get("epoch"), + ) + + return molt, meta + + +def load_molt_from_hf( + filename: str, + repo_id: str = DEFAULT_HF_REPO, + revision: Optional[str] = None, + device: str | torch.device = "cpu", + cache_dir: Optional[str] = None, +) -> tuple[Molt, MoltCheckpointMetadata]: + """Download a MoLT checkpoint from the Hugging Face Hub and load it. + + Defaults `repo_id` to `kylelovesllms/molt-sweeps`. The hub cache handles + re-use across calls; we never copy the file locally. + + Example: + molt, meta = load_molt_from_hf("gpt2-molt-lam-0_00015-50M.ckpt") + """ + local_path = hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + cache_dir=cache_dir, + ) + molt, meta = load_molt(local_path, device=device) + meta.hf_repo_id = repo_id + meta.hf_filename = filename + meta.hf_revision = revision + return molt, meta diff --git a/crosslayer_transcoder/feature_dash/render.py b/crosslayer_transcoder/feature_dash/render.py new file mode 100644 index 0000000..c769325 --- /dev/null +++ b/crosslayer_transcoder/feature_dash/render.py @@ -0,0 +1,46 @@ +"""Copy the static dashboard templates into a dump directory. + +The templates are bundled as package data under +`crosslayer_transcoder/feature_dash/templates/`. We copy them rather than +import them as strings so the output is editable / easy to inspect. + +After `dump_dashboard` writes `data/`, calling `copy_render_assets(out_dir)` +puts: + + /index.html + /dashboard.html + /assets/dashboard.css + /assets/dashboard.js + /assets/index.js + +next to it. Open `index.html` via a local HTTP server (file:// blocks fetch +in most browsers). +""" + +from __future__ import annotations + +import shutil +from pathlib import Path + +TEMPLATE_DIR = Path(__file__).parent / "templates" + +# (source filename, destination relative to out_dir) +_ASSET_MAP: list[tuple[str, str]] = [ + ("index.html", "index.html"), + ("dashboard.html", "dashboard.html"), + ("dashboard.css", "assets/dashboard.css"), + ("dashboard.js", "assets/dashboard.js"), + ("index.js", "assets/index.js"), +] + + +def copy_render_assets(out_dir: str | Path) -> Path: + """Copy the HTML/CSS/JS templates into `out_dir`. Returns out_dir.""" + out_dir = Path(out_dir) + (out_dir / "assets").mkdir(parents=True, exist_ok=True) + for src_name, dst_rel in _ASSET_MAP: + src = TEMPLATE_DIR / src_name + dst = out_dir / dst_rel + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(src, dst) + return out_dir diff --git a/crosslayer_transcoder/feature_dash/templates/bundle.html b/crosslayer_transcoder/feature_dash/templates/bundle.html new file mode 100644 index 0000000..14dd9ee --- /dev/null +++ b/crosslayer_transcoder/feature_dash/templates/bundle.html @@ -0,0 +1,244 @@ + + + + + MoLT Feature Dashboard (bundle) + + + +
Loading…
+ + + + + + + diff --git a/crosslayer_transcoder/feature_dash/templates/dashboard.css b/crosslayer_transcoder/feature_dash/templates/dashboard.css new file mode 100644 index 0000000..a43776a --- /dev/null +++ b/crosslayer_transcoder/feature_dash/templates/dashboard.css @@ -0,0 +1,42 @@ +html { font-family: ui-sans-serif, system-ui, -apple-system, sans-serif; color: #1a1a1a; } +body { max-width: 1100px; margin: 0 auto; padding: 1rem 1.5rem 4rem; } + +.topbar { + display: flex; justify-content: space-between; align-items: baseline; + border-bottom: 1px solid #ddd; padding-bottom: 0.5rem; margin-bottom: 1rem; + font-size: 0.9rem; +} +.topbar a { color: #2a5db0; text-decoration: none; } +.topbar a:hover { text-decoration: underline; } +.topbar .ckpt { color: #666; font-family: ui-monospace, monospace; } + +h1 { margin: 0.5rem 0 0.25rem; font-size: 1.5rem; } +h2 { margin-top: 2rem; font-size: 1.1rem; color: #444; } + +.meta { color: #555; display: flex; flex-wrap: wrap; gap: 1rem; font-size: 0.95rem; } +.meta span { white-space: nowrap; } + +.dead { color: #a00; font-style: italic; padding: 1rem; background: #fff4f4; border-radius: 4px; } + +.example { border: 1px solid #e0e0e0; border-radius: 6px; padding: 0.75rem 1rem; margin: 0.75rem 0; background: #fafafa; } +.example-header { font-size: 0.85rem; color: #666; margin-bottom: 0.5rem; font-family: ui-monospace, monospace; } + +.tokens { font-family: ui-monospace, "SF Mono", monospace; font-size: 0.95rem; line-height: 1.7; word-wrap: break-word; } +.tok { padding: 1px 0; border-radius: 2px; white-space: pre; } +.tok.peak { outline: 1.5px solid #c33; outline-offset: -1px; } + +/* Index page table */ +table.features { width: 100%; border-collapse: collapse; font-size: 0.9rem; } +table.features th, table.features td { padding: 0.35rem 0.6rem; text-align: left; border-bottom: 1px solid #eee; } +table.features th { cursor: pointer; user-select: none; background: #f5f5f5; position: sticky; top: 0; } +table.features th.sorted-asc::after { content: " \25B2"; color: #888; } +table.features th.sorted-desc::after { content: " \25BC"; color: #888; } +table.features tr:hover { background: #f9f9f9; } +table.features td.num { text-align: right; font-variant-numeric: tabular-nums; } +table.features td.dead { color: #aaa; } +table.features a { color: #2a5db0; text-decoration: none; } +table.features a:hover { text-decoration: underline; } + +.index-meta { color: #555; font-size: 0.9rem; margin-bottom: 1rem; } +.filter-row { display: flex; gap: 0.75rem; align-items: center; margin: 0.75rem 0 1rem; font-size: 0.9rem; } +.filter-row input { padding: 0.25rem 0.5rem; font-size: 0.9rem; width: 6rem; } diff --git a/crosslayer_transcoder/feature_dash/templates/dashboard.html b/crosslayer_transcoder/feature_dash/templates/dashboard.html new file mode 100644 index 0000000..75d9cba --- /dev/null +++ b/crosslayer_transcoder/feature_dash/templates/dashboard.html @@ -0,0 +1,22 @@ + + + + + MoLT Feature Dashboard + + + +
+ ← index + +
+ +

Loading…

+
+ +

Top activating examples

+
+ + + + diff --git a/crosslayer_transcoder/feature_dash/templates/dashboard.js b/crosslayer_transcoder/feature_dash/templates/dashboard.js new file mode 100644 index 0000000..914d804 --- /dev/null +++ b/crosslayer_transcoder/feature_dash/templates/dashboard.js @@ -0,0 +1,102 @@ +// Per-feature view. Reads ?feature=, fetches data/features/.json +// and data/metadata.json, renders header + token-highlighted examples. + +const params = new URLSearchParams(location.search); +const featureId = parseInt(params.get('feature') ?? '0', 10); + +function pad(id, width) { + return String(id).padStart(width, '0'); +} + +function widthFor(nFeatures) { + return Math.max(4, String(nFeatures - 1).length); +} + +async function main() { + const md = await fetch('data/metadata.json').then(r => r.json()); + const w = widthFor(md.n_features); + const f = await fetch(`data/features/${pad(featureId, w)}.json`).then(r => r.json()); + + document.title = `Feature #${f.feature_id} — MoLT dashboard`; + document.getElementById('ckpt').textContent = + (md.hf_filename || md.ckpt_path) + ` · layer ${md.layer}`; + document.getElementById('title').textContent = `Feature #${f.feature_id}`; + + const meta = document.getElementById('meta'); + meta.innerHTML = ''; + addMeta(meta, `tier ${f.tier} · rank ${f.rank}`); + addMeta(meta, `activation rate: ${(f.activation_rate * 100).toFixed(3)}%`); + addMeta(meta, `max activation: ${f.max_activation.toFixed(3)}`); + addMeta(meta, `examples: ${f.examples.length}`); + + // Prev/next nav. + const nav = document.createElement('span'); + if (featureId > 0) { + nav.appendChild(link(`?feature=${featureId - 1}`, '← prev')); + nav.appendChild(document.createTextNode(' ')); + } + if (featureId < md.n_features - 1) { + nav.appendChild(link(`?feature=${featureId + 1}`, 'next →')); + } + meta.appendChild(nav); + + const root = document.getElementById('examples'); + if (f.examples.length === 0) { + root.innerHTML = '
This transform did not fire on any token in the sampled corpus.
'; + return; + } + for (const ex of f.examples) { + root.appendChild(renderExample(ex, f.max_activation)); + } +} + +function addMeta(parent, text) { + const s = document.createElement('span'); + s.textContent = text; + parent.appendChild(s); +} + +function link(href, text) { + const a = document.createElement('a'); + a.href = href; + a.textContent = text; + return a; +} + +function renderExample(ex, globalMax) { + const wrap = document.createElement('div'); + wrap.className = 'example'; + + const header = document.createElement('div'); + header.className = 'example-header'; + header.textContent = `peak ${ex.peak_activation.toFixed(3)} · position ${ex.peak_token_pos} of ${ex.tokens.length}`; + wrap.appendChild(header); + + const tokens = document.createElement('div'); + tokens.className = 'tokens'; + // Scale alpha by the per-feature global max so different sequences are + // visually comparable within the same feature. + const scale = globalMax > 0 ? globalMax : 1.0; + for (let i = 0; i < ex.tokens.length; i++) { + const span = document.createElement('span'); + span.className = 'tok'; + if (i === ex.peak_token_pos) span.classList.add('peak'); + const a = ex.activations[i]; + const alpha = Math.max(0, Math.min(1, a / scale)); + span.style.backgroundColor = `rgba(220, 50, 50, ${alpha.toFixed(3)})`; + span.title = `act ${a.toFixed(4)}`; + // Preserve raw whitespace inside the span (CSS white-space: pre handles this). + span.textContent = ex.tokens[i] === '' ? ' ' : ex.tokens[i]; + tokens.appendChild(span); + } + wrap.appendChild(tokens); + + return wrap; +} + +main().catch(err => { + document.getElementById('examples').innerHTML = + `
Failed to load feature data: ${err.message}. ` + + `If you opened this with file://, serve the directory first ` + + `(e.g. python -m http.server).
`; +}); diff --git a/crosslayer_transcoder/feature_dash/templates/index.html b/crosslayer_transcoder/feature_dash/templates/index.html new file mode 100644 index 0000000..c00496f --- /dev/null +++ b/crosslayer_transcoder/feature_dash/templates/index.html @@ -0,0 +1,37 @@ + + + + + MoLT Feature Dashboard — index + + + +

MoLT feature dashboard

+
Loading…
+
+ + + + +
+ + + + + + + + + + + +
idtierrankratemax act
+ + + diff --git a/crosslayer_transcoder/feature_dash/templates/index.js b/crosslayer_transcoder/feature_dash/templates/index.js new file mode 100644 index 0000000..c2ee9bb --- /dev/null +++ b/crosslayer_transcoder/feature_dash/templates/index.js @@ -0,0 +1,120 @@ +// Index page. One fetch of metadata.json populates the sortable table. + +let allRows = []; +let sortKey = 'activation_rate'; +let sortDir = 'desc'; +let metadata = null; + +function pad(id, width) { + return String(id).padStart(width, '0'); +} + +async function main() { + metadata = await fetch('data/metadata.json').then(r => r.json()); + + document.getElementById('meta').textContent = + `${metadata.n_features} transforms · ` + + `${(metadata.hf_filename || metadata.ckpt_path)} · layer ${metadata.layer} · ` + + `${metadata.n_tokens_collected.toLocaleString()} tokens from ${metadata.dashboard_dataset}`; + + // Build the row data. + for (let i = 0; i < metadata.n_features; i++) { + allRows.push({ + feature_id: i, + tier: metadata.feature_tier[i], + rank: metadata.feature_rank[i], + activation_rate: metadata.feature_activation_rate[i], + max_activation: metadata.feature_max_activation[i], + }); + } + + // Populate tier filter. + const uniqueTiers = [...new Set(metadata.feature_tier)].sort((a, b) => a - b); + const tierSel = document.getElementById('tier-filter'); + for (const t of uniqueTiers) { + const opt = document.createElement('option'); + opt.value = String(t); + opt.textContent = `tier ${t} (rank ${metadata.ranks[t]})`; + tierSel.appendChild(opt); + } + + // Wire interactions. + document.querySelectorAll('th[data-key]').forEach(th => { + th.addEventListener('click', () => { + const k = th.dataset.key; + if (sortKey === k) { + sortDir = sortDir === 'desc' ? 'asc' : 'desc'; + } else { + sortKey = k; + sortDir = (k === 'feature_id' || k === 'tier') ? 'asc' : 'desc'; + } + render(); + }); + }); + ['change', 'input'].forEach(evt => { + tierSel.addEventListener(evt, render); + document.getElementById('min-rate').addEventListener(evt, render); + document.getElementById('max-rate').addEventListener(evt, render); + }); + + render(); +} + +function render() { + const tier = document.getElementById('tier-filter').value; + const minRate = parseFloat(document.getElementById('min-rate').value) || 0; + const maxRate = parseFloat(document.getElementById('max-rate').value); + const maxRateOk = Number.isFinite(maxRate) ? maxRate : 1; + + const filtered = allRows.filter(r => + (tier === '' || String(r.tier) === tier) && + r.activation_rate >= minRate && + r.activation_rate <= maxRateOk + ); + + filtered.sort((a, b) => { + const av = a[sortKey], bv = b[sortKey]; + if (av < bv) return sortDir === 'asc' ? -1 : 1; + if (av > bv) return sortDir === 'asc' ? 1 : -1; + return a.feature_id - b.feature_id; + }); + + document.querySelectorAll('th[data-key]').forEach(th => { + th.classList.remove('sorted-asc', 'sorted-desc'); + if (th.dataset.key === sortKey) { + th.classList.add(sortDir === 'asc' ? 'sorted-asc' : 'sorted-desc'); + } + }); + + document.getElementById('count').textContent = + `${filtered.length.toLocaleString()} of ${allRows.length.toLocaleString()} shown`; + + const tbody = document.querySelector('#features tbody'); + // Render up to 2000 rows; the user can filter further if they need more. + const cap = 2000; + const rows = filtered.slice(0, cap); + const html = rows.map(r => { + const isDead = r.activation_rate === 0; + return ` + #${r.feature_id} + ${r.tier} + ${r.rank} + ${(r.activation_rate * 100).toFixed(3)}% + ${r.max_activation.toFixed(3)} + `; + }).join(''); + tbody.innerHTML = html; + if (filtered.length > cap) { + tbody.insertAdjacentHTML( + 'beforeend', + `${filtered.length - cap} more rows hidden — narrow the filter` + ); + } +} + +main().catch(err => { + document.getElementById('meta').innerHTML = + `Failed to load metadata.json: ${err.message}. ` + + `If you opened this with file://, serve the directory first ` + + `(e.g. python -m http.server).`; +}); diff --git a/crosslayer_transcoder/model/__init__.py b/crosslayer_transcoder/model/__init__.py index e4f0aae..6a88618 100644 --- a/crosslayer_transcoder/model/__init__.py +++ b/crosslayer_transcoder/model/__init__.py @@ -4,11 +4,13 @@ from .clt import CrossLayerTranscoder from .clt_lightning import CrossLayerTranscoderModule +from .molt import Molt from .topk import BatchTopK, PerLayerBatchTopK, PerLayerTopK __all__ = [ "CrossLayerTranscoder", "CrossLayerTranscoderModule", + "Molt", "BatchTopK", "PerLayerTopK", "PerLayerBatchTopK", diff --git a/crosslayer_transcoder/model/clt_lightning.py b/crosslayer_transcoder/model/clt_lightning.py index 8e4a020..2256028 100644 --- a/crosslayer_transcoder/model/clt_lightning.py +++ b/crosslayer_transcoder/model/clt_lightning.py @@ -2,7 +2,7 @@ import os import subprocess import time -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import lightning as L import psutil @@ -23,6 +23,7 @@ Decoder, ) from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt from crosslayer_transcoder.model.topk import BatchTopK @@ -30,7 +31,7 @@ class CrossLayerTranscoderModule(L.LightningModule): def __init__( self, # Pre-constructed modules - model: CrossLayerTranscoder, + model: Union[CrossLayerTranscoder, Molt], replacement_model: Optional[ReplacementModelAccuracy] = None, dead_features: Optional[DeadFeatures] = None, # Training parameters @@ -85,17 +86,23 @@ def __init__( self.beta2 = beta2 self.log_metrics_every = log_metrics_every - assert self.model.encoder.n_layers == self.model.decoder.n_layers, ( - "Encoder and decoder must have the same number of layers" - ) + if isinstance(self.model, Molt): + self.register_buffer( + "last_active", + torch.zeros((self.model.n_features,), dtype=torch.long), + ) + else: + assert self.model.encoder.n_layers == self.model.decoder.n_layers, ( + "Encoder and decoder must have the same number of layers" + ) - self.register_buffer( - "last_active", - torch.zeros( - (self.model.encoder.n_layers, self.model.encoder.d_features), - dtype=torch.long, - ), - ) + self.register_buffer( + "last_active", + torch.zeros( + (self.model.encoder.n_layers, self.model.encoder.d_features), + dtype=torch.long, + ), + ) def configure_model(self): # Apply compilation if requested @@ -565,3 +572,68 @@ def training_step(self, batch, batch_idx): torch.cuda.memory._record_memory_history(enabled=None) exit() return loss + + +class MoltModule(CrossLayerTranscoderModule): + def __init__( + self, + lambda_sparsity: float = 0.0002, + c_sparsity: float = 0.1, + use_tanh: bool = True, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._lambda = lambda_sparsity + self.c = c_sparsity + self.use_tanh = use_tanh + + def current_sparsity_penalty(self): + n_steps = self.trainer.max_steps + current_step = ( + self.global_step + ) # use global step instead of batch idx to work with gradient accumulation + cur_lambda = self._lambda * (current_step / n_steps) + self.log("training/sparsity_penalty", cur_lambda) + return cur_lambda + + def forward(self, batch, layer): + return self.model.forward(batch, layer) + + def training_step(self, batch, batch_idx): + if batch_idx == 0: + self.model.initialize_standardizers(batch) + self.log("model/d_latents", self.model.d_latents) + self.log("model/n_features", self.model.n_features) + + layer = 8 + + # Forward pass + resid, mlp_out = batch[:, 0], batch[:, 1] + resid = resid[:, layer] + mlp_out = mlp_out[:, layer] + gate, recons_norm, recons = self.model.forward(resid, layer) + + self.update_dead_features(gate) + # Compute MSE loss + mse = (recons_norm - self.model.output_standardizer.standardize(mlp_out, layer)) ** 2 + + # Compute Sparsity Loss + norms = self.model.transform_norm() + weighted_norms = norms * gate + self.log("model/weighted_norms_mean", weighted_norms.detach().mean().cpu()) + + if self.use_tanh: + weighted_norms = torch.tanh(weighted_norms * self.c) + sparsity = self.current_sparsity_penalty() * weighted_norms.sum(dim=-1).mean() + self.log("training/sparsity_loss", sparsity) + self.log("L0", (gate > 0.0).float().sum() / gate.shape[0]) + + loss = mse.mean() + sparsity + self.log("training/mse", mse.mean()) + self.log("training/loss", loss) + + if batch_idx % self.log_metrics_every == 0: + pass + + return loss diff --git a/crosslayer_transcoder/model/jumprelu.py b/crosslayer_transcoder/model/jumprelu.py index d21b7ec..c977830 100644 --- a/crosslayer_transcoder/model/jumprelu.py +++ b/crosslayer_transcoder/model/jumprelu.py @@ -52,7 +52,8 @@ def backward(ctx, grad_output): class JumpReLU(SerializableModule): def __init__(self, theta=0.0, bandwidth=1.0, n_layers=12, d_features=768 * 8): super().__init__() - self.theta = nn.Parameter(torch.full((1, n_layers, d_features), theta)) + shape = (1, n_layers, d_features) if n_layers > 1 else (1, d_features) + self.theta = nn.Parameter(torch.full(shape, theta)) self.register_buffer("bandwidth", torch.tensor(bandwidth)) self._init_theta = theta self.n_layers = n_layers diff --git a/crosslayer_transcoder/model/molt.py b/crosslayer_transcoder/model/molt.py new file mode 100644 index 0000000..afb3fd3 --- /dev/null +++ b/crosslayer_transcoder/model/molt.py @@ -0,0 +1,93 @@ +import einops +import torch +import torch.nn as nn +from jaxtyping import Float + + +class Molt(nn.Module): + def __init__( + self, + d_acts: int, + N: int, + nonlinearity: nn.Module, + input_standardizer: nn.Module, + output_standardizer: nn.Module, + ranks: list[int] = [512, 256, 128, 64, 32], + ): + super().__init__() + + self.d_acts = d_acts + self.nonlinearity = nonlinearity + self.input_standardizer = input_standardizer + self.output_standardizer = output_standardizer + Us = [] + Vs = [] + rank_multiplier = 1 + n_features = 0 + d_latents = 0 + for rank in ranks: + Us.append(nn.Parameter(torch.empty(N * rank_multiplier, rank, d_acts))) + Vs.append(nn.Parameter(torch.empty(N * rank_multiplier, d_acts, rank))) + n_features += N * rank_multiplier + d_latents += N * rank_multiplier * rank + rank_multiplier *= 2 + self.n_features = n_features + self.e = nn.Linear(d_acts, n_features) + self.Us = nn.ParameterList(Us) + self.Vs = nn.ParameterList(Vs) + + print(f"d_latents (transcoder equivalent): {d_latents}") + self.d_latents = d_latents + + self.reset_parameters() + + def reset_parameters(self): + for U in self.Us: + nn.init.xavier_uniform_(U) + for V in self.Vs: + nn.init.xavier_uniform_(V) + + def transform_norm(self): + norms = [] + for U, V in zip(self.Us, self.Vs): + uv = einops.einsum( + U, + V, + "n_transforms d_transform d_acts_out, n_transforms d_acts_in d_transform -> n_transforms d_acts_in d_acts_out", + ) + norms.append(torch.norm(uv, dim=(1, 2))) + return torch.cat(norms, dim=0) + + def forward( + self, acts: Float[torch.Tensor, "batch_size d_acts"], layer: int + ) -> Float[torch.Tensor, "batch_size d_acts"]: + acts = self.input_standardizer(acts, layer) + pre_actvs = self.e(acts) + gate = self.nonlinearity(pre_actvs) # (batch, n_transforms) + + raw_recons = [] + for U, V in zip(self.Us, self.Vs): + latents = einops.einsum( + acts, + V, + "batch d_acts, n_transforms d_acts d_transform -> batch n_transforms d_transform", + ) + raw_recons.append( + einops.einsum( + latents, + U, + "batch n_transforms d_transform, n_transforms d_transform d_acts -> batch n_transforms d_acts", + ) + ) + + raw_recons = torch.cat(raw_recons, dim=1) + + weighted_recons = gate.unsqueeze(-1) * raw_recons + recons_norm = weighted_recons.sum(dim=1) + + recons = self.output_standardizer(recons_norm, layer) + return gate, recons_norm, recons + + def initialize_standardizers(self, batch: Float[torch.Tensor, "batch_size io n_layers d_acts"]): + self.input_standardizer.initialize_from_batch(batch) + self.output_standardizer.initialize_from_batch(batch) diff --git a/pyproject.toml b/pyproject.toml index 849e873..52e8de7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,9 @@ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.build.targets.wheel.force-include] +"crosslayer_transcoder/feature_dash/templates" = "crosslayer_transcoder/feature_dash/templates" + [project] name = "crosslayer-transcoder" version = "0.1.0" diff --git a/tests/test_feature_dash_bundle.py b/tests/test_feature_dash_bundle.py new file mode 100644 index 0000000..d680bee --- /dev/null +++ b/tests/test_feature_dash_bundle.py @@ -0,0 +1,250 @@ +"""Tests for the single-file `bundle.html` builder. + +Two paths to the bundle: + - `make_bundle_from_disk(out_dir)`: read an existing dump, write bundle.html. + - `make_bundle(collector, meta, tokenizer, out_path, ...)`: build directly + from in-memory state, no disk dump needed. + +Both should produce a self-contained file (no `fetch(` calls, all data inlined) +that contains every feature's payload. +""" + +from __future__ import annotations + +import json +import re +from pathlib import Path + +import pytest +import torch + +from crosslayer_transcoder.feature_dash.bundle import ( + default_bundle_filename, + make_bundle, + make_bundle_from_disk, +) +from crosslayer_transcoder.feature_dash.collect import GateCollector +from crosslayer_transcoder.feature_dash.dump import dump_dashboard +from crosslayer_transcoder.feature_dash.load import MoltCheckpointMetadata + + +class _StubTokenizer: + def decode(self, ids): + return f" tok{ids[0]}" + + +def _populate_collector(F: int, K: int, T: int) -> GateCollector: + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + tok = torch.arange(2 * T, dtype=torch.long).reshape(2, T) + g = torch.zeros(2, T, F) + g[0, 1, 0] = 1.0 + g[1, 3, 0] = 2.5 + g[0, 0, 1] = 0.5 + coll.update(tok, g) + return coll + + +def _meta(F: int, ranks: list[int], N: int) -> MoltCheckpointMetadata: + feature_tier: list[int] = [] + feature_rank: list[int] = [] + for t, r in enumerate(ranks): + n_in_tier = N * (2**t) + feature_tier.extend([t] * n_in_tier) + feature_rank.extend([r] * n_in_tier) + return MoltCheckpointMetadata( + ckpt_path="dummy.ckpt", + d_acts=8, + n_features=F, + n_layers=2, + ranks=ranks, + N=N, + feature_tier=feature_tier, + feature_rank=feature_rank, + base_model_name="stub", + training_dataset="stub-ds", + global_step=123, + epoch=0, + ) + + +def _extract_inlined_json(html: str, script_id: str) -> dict: + """Pull a `` payload.""" + m = re.search( + rf'', + html, + flags=re.DOTALL, + ) + assert m, f"missing inlined script #{script_id}" + raw = m.group(1) + # Reverse the `<` hardening before parsing. + raw = raw.replace(r"<", "<") + return json.loads(raw) + + +def test_bundle_contains_no_fetch_calls(tmp_path: Path): + """Bundle must work from file:// — no network/disk fetches allowed.""" + F, K, T, N = 3, 2, 5, 1 + coll = _populate_collector(F, K, T) + meta = _meta(F, [4, 2], N) + + out = make_bundle( + collector=coll, + meta=meta, + tokenizer=_StubTokenizer(), + out_path=tmp_path / "bundle.html", + layer=8, + window=2, + ) + text = out.read_text() + # No fetch(... call anywhere — neither real nor templated. + assert "fetch(" not in text + + +def test_bundle_inlines_metadata_and_all_features(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + coll = _populate_collector(F, K, T) + meta = _meta(F, [4, 2], N) + + bundle_path = make_bundle( + collector=coll, meta=meta, tokenizer=_StubTokenizer(), + out_path=tmp_path, layer=8, window=2, + ) + # When out_path is a directory, the bundle is named after the checkpoint. + # `_meta()` sets ckpt_path="dummy.ckpt" and no hf_filename, so stem="dummy". + assert bundle_path == tmp_path / "bundle_dummy.html" + + html = bundle_path.read_text() + md = _extract_inlined_json(html, "metadata") + feats = _extract_inlined_json(html, "features") + + assert md["n_features"] == F + assert md["layer"] == 8 + assert md["feature_tier"] == [0, 1, 1] + assert md["feature_rank"] == [4, 2, 2] + assert len(md["feature_activation_rate"]) == F + assert len(md["feature_max_activation"]) == F + + assert set(feats.keys()) == {"0", "1", "2"} + for k, v in feats.items(): + assert v["feature_id"] == int(k) + assert {"tier", "rank", "activation_rate", "max_activation", "examples"} <= set(v.keys()) + + # Feature 0 has two real examples and feature 2 is dead. + assert len(feats["0"]["examples"]) == 2 + assert feats["2"]["examples"] == [] + + +def test_bundle_css_is_inlined(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + coll = _populate_collector(F, K, T) + meta = _meta(F, [4, 2], N) + out = make_bundle( + collector=coll, meta=meta, tokenizer=_StubTokenizer(), + out_path=tmp_path / "bundle.html", layer=8, window=2, + ) + html = out.read_text() + # No external stylesheet link, and a non-empty inline " in html + assert "table.features" in html # one of our CSS rules + + +def test_bundle_from_disk_reads_dumped_dashboard(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + coll = _populate_collector(F, K, T) + meta = _meta(F, [4, 2], N) + + dump_dashboard( + collector=coll, meta=meta, tokenizer=_StubTokenizer(), + out_dir=tmp_path, layer=8, window=2, copy_assets=False, + ) + bundle_path = make_bundle_from_disk(tmp_path) + + # ckpt_path="dummy.ckpt" -> bundle_dummy.html + assert bundle_path == tmp_path / "bundle_dummy.html" + html = bundle_path.read_text() + md = _extract_inlined_json(html, "metadata") + feats = _extract_inlined_json(html, "features") + assert md["n_features"] == F + assert len(feats) == F + + +def test_bundle_from_disk_errors_when_data_missing(tmp_path: Path): + with pytest.raises(FileNotFoundError): + make_bundle_from_disk(tmp_path) + + +def test_default_bundle_filename_prefers_hf_filename(): + F, N = 3, 1 + meta = _meta(F, [4, 2], N) + meta.hf_filename = "gpt2-molt-lam-0_00015-50M.ckpt" + meta.ckpt_path = "/some/random/local/path.ckpt" + assert default_bundle_filename(meta) == "bundle_gpt2-molt-lam-0_00015-50M.html" + + +def test_default_bundle_filename_falls_back_to_ckpt_path(): + F, N = 3, 1 + meta = _meta(F, [4, 2], N) + meta.hf_filename = None + meta.ckpt_path = "/runs/exp42/checkpoint-final.ckpt" + assert default_bundle_filename(meta) == "bundle_checkpoint-final.html" + + +def test_default_bundle_filename_handles_unsafe_chars(): + F, N = 3, 1 + meta = _meta(F, [4, 2], N) + meta.hf_filename = "weird/name with spaces.ckpt" + # The stem is "name with spaces" (Path.stem already drops the directory). + # Spaces are sanitised to '-'. + assert default_bundle_filename(meta) == "bundle_name-with-spaces.html" + + +def test_make_bundle_explicit_filepath_is_respected(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + coll = _populate_collector(F, K, T) + meta = _meta(F, [4, 2], N) + + explicit = tmp_path / "my-custom-name.html" + out = make_bundle( + collector=coll, meta=meta, tokenizer=_StubTokenizer(), + out_path=explicit, layer=8, window=2, + ) + # When the user passes a file path, we don't override it. + assert out == explicit + + +def test_bundle_handles_script_tag_in_token_text(tmp_path: Path): + """A token string containing `` must not break the bundle. + + The sanitiser replaces `<` with `\\u003c`; the JS in the bundle parses the + JSON via `JSON.parse`, which decodes the unicode escape back. We just + verify here that the raw HTML doesn't contain a literal `` inside + the data payload. + """ + + class _NastyTokenizer: + def decode(self, ids): + return "BAD" + + F, K, T, N = 3, 2, 5, 1 + coll = _populate_collector(F, K, T) + meta = _meta(F, [4, 2], N) + + out = make_bundle( + collector=coll, meta=meta, tokenizer=_NastyTokenizer(), + out_path=tmp_path / "bundle.html", layer=8, window=2, + ) + html = out.read_text() + # Find the data scripts and check their content is sanitised. + m = re.search( + r'', + html, flags=re.DOTALL, + ) + assert m + payload = m.group(1) + # The sanitised payload must not contain a literal "') == 3 diff --git a/tests/test_feature_dash_collect.py b/tests/test_feature_dash_collect.py new file mode 100644 index 0000000..dcbdb85 --- /dev/null +++ b/tests/test_feature_dash_collect.py @@ -0,0 +1,158 @@ +"""Tests for the GateCollector accumulator (no LM, no dataset). + +Synthetic gate batches go in, we check that firing counts, activation rates, +top-K ordering, and per-token activation traces come out right. +""" + +from __future__ import annotations + +import pytest +import torch + +from crosslayer_transcoder.feature_dash.collect import GateCollector + + +F = 3 # features +K = 4 # top-K +T = 6 # seq_len + + +def _zero_gates(B: int) -> torch.Tensor: + return torch.zeros(B, T, F) + + +def test_firing_count_and_total_tokens(): + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + tok = torch.arange(2 * T, dtype=torch.long).reshape(2, T) + g = _zero_gates(2) + g[0, 1, 0] = 0.5 + g[0, 4, 0] = 1.5 + g[1, 2, 0] = 0.0 # 0 doesn't count as active (gate > 0) + g[1, 3, 1] = 0.7 + + coll.update(tok, g) + + assert coll.total_tokens == 2 * T + assert coll.firing_count.tolist() == [2, 1, 0] + rates = coll.activation_rate() + assert torch.allclose(rates, torch.tensor([2 / 12, 1 / 12, 0.0])) + + +def test_max_activation_tracks_global_max_across_batches(): + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + g1 = _zero_gates(1) + g1[0, 0, 0] = 1.0 + coll.update(torch.zeros(1, T, dtype=torch.long), g1) + + g2 = _zero_gates(1) + g2[0, 0, 0] = 0.5 + g2[0, 0, 1] = 4.2 + coll.update(torch.zeros(1, T, dtype=torch.long), g2) + + # Once a batch lands, max_activation rises from -inf to >=0 even for + # non-firing features (max of -inf and 0 is 0). That's fine for the + # dashboard — feature_summary surfaces 0.0 for dead features. + assert torch.allclose( + coll.max_activation, torch.tensor([1.0, 4.2, 0.0]), atol=1e-6 + ) + + +def test_topk_orders_by_per_sequence_peak(): + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + # 5 sequences, all activate feature 0 with different peaks at different positions. + B = 5 + tok = torch.arange(B * T, dtype=torch.long).reshape(B, T) + g = _zero_gates(B) + peaks_in = [0.1, 3.0, 2.0, 0.5, 5.0] + pos_in = [0, 2, 5, 1, 4] + for b, (p, pos) in enumerate(zip(peaks_in, pos_in)): + g[b, pos, 0] = p + coll.update(tok, g) + + summary = coll.feature_summary(0) + + # Top-K = 4: should drop the 0.1 peak. + assert summary.top_peaks == pytest.approx([5.0, 3.0, 2.0, 0.5]) + assert summary.activation_rate == pytest.approx(5 / (B * T)) + assert summary.max_activation == 5.0 + # Token ids of top sequence should match seq with peak 5.0 (b=4). + assert summary.top_token_ids[0] == tok[4].tolist() + # The activation trace's argmax should match where we placed the peak. + trace0 = summary.top_activations[0] + assert max(range(T), key=lambda i: trace0[i]) == pos_in[4] + + +def test_topk_merges_across_multiple_batches(): + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + # Batch 1: peaks 1, 2, 3 for feature 0. + tok1 = torch.full((3, T), 11, dtype=torch.long) + g1 = _zero_gates(3) + g1[0, 0, 0] = 1.0 + g1[1, 0, 0] = 2.0 + g1[2, 0, 0] = 3.0 + coll.update(tok1, g1) + + # Batch 2: peaks 0.5, 4.0 for feature 0. Top-4 across the two batches + # should be {4.0, 3.0, 2.0, 1.0}, dropping 0.5. + tok2 = torch.full((2, T), 22, dtype=torch.long) + g2 = _zero_gates(2) + g2[0, 0, 0] = 0.5 + g2[1, 0, 0] = 4.0 + coll.update(tok2, g2) + + summary = coll.feature_summary(0) + assert summary.top_peaks == [4.0, 3.0, 2.0, 1.0] + # Top sequence is from batch 2 (its tokens are all 22). + assert summary.top_token_ids[0] == [22] * T + # The 1.0-peak entry is from batch 1 (tokens all 11). + assert summary.top_token_ids[3] == [11] * T + + +def test_dead_feature_summary_is_empty(): + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + g = _zero_gates(2) + g[0, 0, 0] = 1.0 # only feature 0 fires + coll.update(torch.zeros(2, T, dtype=torch.long), g) + + dead = coll.feature_summary(2) + assert dead.activation_rate == 0.0 + assert dead.max_activation == 0.0 + assert dead.top_peaks == [] + assert dead.top_token_ids == [] + assert dead.top_activations == [] + + +def test_per_token_trace_is_feature_specific(): + """The trace stored for feature i must be gates[..., i] for that sequence, + not any other feature.""" + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + tok = torch.zeros(1, T, dtype=torch.long) + g = _zero_gates(1) + g[0, :, 0] = torch.arange(T, dtype=torch.float32) # feat 0: 0..T-1 + g[0, :, 1] = torch.arange(T, dtype=torch.float32) * 10 # feat 1: 0,10,20,... + coll.update(tok, g) + + s0 = coll.feature_summary(0) + s1 = coll.feature_summary(1) + assert s0.top_activations[0] == list(range(T)) + assert s1.top_activations[0] == [i * 10 for i in range(T)] + + +def test_update_rejects_wrong_shape(): + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + bad_g = torch.zeros(1, T + 1, F) + try: + coll.update(torch.zeros(1, T + 1, dtype=torch.long), bad_g) + except ValueError: + return + raise AssertionError("expected ValueError for seq_len mismatch") + + +def test_update_rejects_wrong_feature_count(): + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + bad_g = torch.zeros(1, T, F + 1) + try: + coll.update(torch.zeros(1, T, dtype=torch.long), bad_g) + except ValueError: + return + raise AssertionError("expected ValueError for feature count mismatch") diff --git a/tests/test_feature_dash_dump.py b/tests/test_feature_dash_dump.py new file mode 100644 index 0000000..baf3e63 --- /dev/null +++ b/tests/test_feature_dash_dump.py @@ -0,0 +1,292 @@ +"""Tests for windowing (Step 3) and JSON dump (Step 4). + +The windowing helper is checked with a stub tokenizer so the test stays CPU +and offline. The dump test builds a tiny GateCollector + MoltCheckpointMetadata +by hand, runs `dump_dashboard`, and checks the file layout and JSON schema. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +import torch + +from crosslayer_transcoder.feature_dash.collect import ( + FeatureSummary, + GateCollector, + window_example, + window_feature_summary, +) +from crosslayer_transcoder.feature_dash.dump import dump_dashboard +from crosslayer_transcoder.feature_dash.load import MoltCheckpointMetadata + + +class _StubTokenizer: + """Minimal HF-tokenizer-like object: id -> string with a leading space.""" + + def decode(self, ids): + # Match HF behavior of single-id decode for byte-level BPE: returns + # the raw token text including leading whitespace where applicable. + return f" tok{ids[0]}" + + +# ---- Step 3 ----------------------------------------------------------------- + + +def test_window_example_centers_on_peak(): + ids = list(range(10)) # 0..9 + acts = [0.0] * 10 + acts[6] = 5.0 # peak at index 6 + + out = window_example(ids, acts, _StubTokenizer(), window=2) + + # window=2 -> indices [4, 5, 6, 7, 8] = 5 tokens + assert out["tokens"] == [" tok4", " tok5", " tok6", " tok7", " tok8"] + assert out["activations"] == [0.0, 0.0, 5.0, 0.0, 0.0] + assert out["peak_token_pos"] == 2 # index of 6 inside the window + assert out["peak_activation"] == 5.0 + + +def test_window_example_clips_at_left_edge(): + ids = list(range(10)) + acts = [0.0] * 10 + acts[1] = 3.0 # peak near the start + + out = window_example(ids, acts, _StubTokenizer(), window=4) + + # window=4 around pos 1 -> [0..5], length 6, peak at pos 1 + assert out["tokens"] == [f" tok{i}" for i in range(6)] + assert out["peak_token_pos"] == 1 + assert out["peak_activation"] == 3.0 + + +def test_window_example_clips_at_right_edge(): + ids = list(range(10)) + acts = [0.0] * 10 + acts[9] = 7.0 # peak at the end + + out = window_example(ids, acts, _StubTokenizer(), window=3) + + # window=3 around pos 9 -> [6..9], length 4, peak at last pos within window + assert out["tokens"] == [" tok6", " tok7", " tok8", " tok9"] + assert out["peak_token_pos"] == 3 + assert out["peak_activation"] == 7.0 + + +def test_window_example_rejects_mismatched_lengths(): + with pytest.raises(ValueError): + window_example([1, 2, 3], [0.1, 0.2], _StubTokenizer(), window=1) + + +def test_window_feature_summary_passes_through_metadata(): + summary = FeatureSummary( + feature_id=42, + activation_rate=0.05, + max_activation=2.5, + top_peaks=[2.5, 1.0], + top_token_ids=[[10, 11, 12], [20, 21, 22]], + top_activations=[[0.0, 2.5, 0.0], [1.0, 0.0, 0.0]], + ) + out = window_feature_summary(summary, _StubTokenizer(), window=10) + + assert out["feature_id"] == 42 + assert out["activation_rate"] == 0.05 + assert out["max_activation"] == 2.5 + assert len(out["examples"]) == 2 + assert out["examples"][0]["peak_activation"] == 2.5 + assert out["examples"][1]["peak_activation"] == 1.0 + + +# ---- Step 4 ----------------------------------------------------------------- + + +def _populate_collector(F: int, K: int, T: int) -> GateCollector: + coll = GateCollector(n_features=F, top_k=K, seq_len=T) + # feature 0 fires on two sequences with different peaks + tok = torch.arange(2 * T, dtype=torch.long).reshape(2, T) + g = torch.zeros(2, T, F) + g[0, 1, 0] = 1.0 + g[1, 3, 0] = 2.5 + g[0, 0, 1] = 0.5 # feature 1 fires once + coll.update(tok, g) + return coll + + +def _meta(F: int, ranks: list[int], N: int) -> MoltCheckpointMetadata: + feature_tier: list[int] = [] + feature_rank: list[int] = [] + for t, r in enumerate(ranks): + n_in_tier = N * (2**t) + feature_tier.extend([t] * n_in_tier) + feature_rank.extend([r] * n_in_tier) + return MoltCheckpointMetadata( + ckpt_path="dummy.ckpt", + d_acts=8, + n_features=F, + n_layers=2, + ranks=ranks, + N=N, + feature_tier=feature_tier, + feature_rank=feature_rank, + base_model_name="stub", + training_dataset="stub-ds", + global_step=123, + epoch=0, + ) + + +def test_dump_dashboard_writes_metadata_and_features(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + ranks = [4, 2] # tier 0: 1 feat, tier 1: 2 feats -> 3 + coll = _populate_collector(F, K, T) + meta = _meta(F, ranks, N) + + data_dir = dump_dashboard( + collector=coll, + meta=meta, + tokenizer=_StubTokenizer(), + out_dir=tmp_path, + layer=8, + dataset_name="Skylion007/openwebtext", + window=2, + ) + + assert data_dir == tmp_path / "data" + assert (data_dir / "metadata.json").is_file() + + feat_dir = data_dir / "features" + files = sorted(feat_dir.glob("*.json")) + assert len(files) == F + # Filenames are zero-padded to width=4 (max(4, len("2"))). + assert [p.name for p in files] == ["0000.json", "0001.json", "0002.json"] + + md = json.loads((data_dir / "metadata.json").read_text()) + assert md["schema_version"] == 1 + assert md["n_features"] == F + assert md["ranks"] == ranks + assert md["layer"] == 8 + assert md["dashboard_dataset"] == "Skylion007/openwebtext" + assert md["n_tokens_collected"] == 2 * T + assert md["seq_len"] == T + assert md["top_k"] == K + assert md["window"] == 2 + assert md["feature_tier"] == [0, 1, 1] + assert md["feature_rank"] == [4, 2, 2] + # Per-feature aggregates so the index page renders from one fetch. + assert len(md["feature_activation_rate"]) == F + assert len(md["feature_max_activation"]) == F + assert md["feature_activation_rate"][0] == pytest.approx(2 / (2 * T)) + assert md["feature_max_activation"][0] == pytest.approx(2.5) + assert md["feature_activation_rate"][2] == 0.0 # dead feature + assert md["feature_max_activation"][2] == 0.0 # max clamped from -inf to 0 + + +def test_dump_dashboard_copies_render_assets(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + ranks = [4, 2] + coll = _populate_collector(F, K, T) + meta = _meta(F, ranks, N) + + dump_dashboard( + collector=coll, + meta=meta, + tokenizer=_StubTokenizer(), + out_dir=tmp_path, + layer=8, + window=2, + ) + # Static assets land alongside data/. + for rel in [ + "index.html", + "dashboard.html", + "assets/dashboard.css", + "assets/dashboard.js", + "assets/index.js", + ]: + assert (tmp_path / rel).is_file(), f"missing {rel}" + + +def test_dump_dashboard_skips_assets_when_disabled(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + coll = _populate_collector(F, K, T) + meta = _meta(F, [4, 2], N) + + dump_dashboard( + collector=coll, + meta=meta, + tokenizer=_StubTokenizer(), + out_dir=tmp_path, + layer=8, + window=2, + copy_assets=False, + ) + assert not (tmp_path / "index.html").exists() + assert not (tmp_path / "assets").exists() + # data/ still written. + assert (tmp_path / "data" / "metadata.json").is_file() + + +def test_dump_dashboard_feature_payload_schema(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + ranks = [4, 2] + coll = _populate_collector(F, K, T) + meta = _meta(F, ranks, N) + + dump_dashboard( + collector=coll, + meta=meta, + tokenizer=_StubTokenizer(), + out_dir=tmp_path, + layer=8, + window=2, + ) + feat0 = json.loads((tmp_path / "data" / "features" / "0000.json").read_text()) + + # Required schema keys. + assert set(feat0.keys()) >= { + "feature_id", + "tier", + "rank", + "activation_rate", + "max_activation", + "examples", + } + assert feat0["feature_id"] == 0 + assert feat0["tier"] == 0 + assert feat0["rank"] == 4 + assert feat0["activation_rate"] == pytest.approx(2 / (2 * T)) + assert feat0["max_activation"] == pytest.approx(2.5) + + # Two examples, sorted desc by peak. + assert len(feat0["examples"]) == 2 + assert feat0["examples"][0]["peak_activation"] == pytest.approx(2.5) + assert feat0["examples"][1]["peak_activation"] == pytest.approx(1.0) + + ex0 = feat0["examples"][0] + assert len(ex0["tokens"]) == len(ex0["activations"]) + # peak_token_pos must index the argmax of the windowed activations. + assert ( + max(range(len(ex0["activations"])), key=lambda i: ex0["activations"][i]) + == ex0["peak_token_pos"] + ) + + +def test_dump_dashboard_dead_feature_yields_no_examples(tmp_path: Path): + F, K, T, N = 3, 2, 5, 1 + ranks = [4, 2] + coll = _populate_collector(F, K, T) # feature 2 never fires + meta = _meta(F, ranks, N) + + dump_dashboard( + collector=coll, + meta=meta, + tokenizer=_StubTokenizer(), + out_dir=tmp_path, + layer=8, + window=2, + ) + feat2 = json.loads((tmp_path / "data" / "features" / "0002.json").read_text()) + assert feat2["activation_rate"] == 0.0 + assert feat2["examples"] == [] diff --git a/tests/test_feature_dash_integration.py b/tests/test_feature_dash_integration.py new file mode 100644 index 0000000..38cb2ba --- /dev/null +++ b/tests/test_feature_dash_integration.py @@ -0,0 +1,189 @@ +"""Integration smoke test for the feature-dash pipeline. + +Wires together: build tiny MoLT -> save Lightning ckpt -> load_molt -> +GateCollector with synthetic gates -> dump_dashboard -> assert the full file +contract holds. + +We bypass `collect_features` (which would need a real GPT-2 + HF dataset) by +feeding the collector hand-crafted batches. The integration we care about +here is the file pipeline: that the loaded MoLT's metadata flows correctly +into the JSON, and that all the per-feature contracts the renderer relies on +are satisfied. + +CLI smoke (Step 6) is also exercised here via argparse `--help`. +""" + +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +import lightning as L +import pytest +import torch +from transformers import GPT2TokenizerFast + +from crosslayer_transcoder.feature_dash import ( + GateCollector, + dump_dashboard, + load_molt, +) +from crosslayer_transcoder.model.clt_lightning import MoltModule +from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt +from crosslayer_transcoder.model.standardize import ( + DimensionwiseInputStandardizer, + DimensionwiseOutputStandardizer, +) + + +D_ACTS = 16 +N_LAYERS_LM = 4 +N = 2 +RANKS = [4, 2] # tier 0: 2 feats, tier 1: 4 feats -> 6 features total +LAYER = 1 +SEQ_LEN = 8 +TOP_K = 3 + + +def _build_molt() -> Molt: + n_features = N * 1 + N * 2 # 6 + nonlin = JumpReLU(theta=0.05, bandwidth=1.0, n_layers=1, d_features=n_features) + in_std = DimensionwiseInputStandardizer(n_layers=N_LAYERS_LM, activation_dim=D_ACTS) + out_std = DimensionwiseOutputStandardizer( + n_layers=N_LAYERS_LM, activation_dim=D_ACTS + ) + fake_batch = torch.randn(8, 2, N_LAYERS_LM, D_ACTS) + in_std.initialize_from_batch(fake_batch) + out_std.initialize_from_batch(fake_batch) + + return Molt( + d_acts=D_ACTS, + N=N, + ranks=RANKS, + nonlinearity=nonlin, + input_standardizer=in_std, + output_standardizer=out_std, + ) + + +def _save_lightning_checkpoint(tmp_path: Path, molt: Molt) -> str: + module = MoltModule(model=molt) + trainer = L.Trainer( + accelerator="cpu", + devices=1, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + max_steps=0, + ) + trainer.strategy.connect(module) + ckpt = tmp_path / "molt.ckpt" + trainer.save_checkpoint(str(ckpt)) + return str(ckpt) + + +def _populate_collector_with_molt_gates( + molt: Molt, collector: GateCollector, n_batches: int = 4, batch_size: int = 4 +) -> None: + """Run synthetic resid through the loaded MoLT to get realistic gates.""" + torch.manual_seed(0) + for _ in range(n_batches): + # Random residuals; the standardizer will re-center them. + resid = torch.randn(batch_size, SEQ_LEN, D_ACTS) + resid_std = molt.input_standardizer(resid, LAYER) + gates = molt.nonlinearity(molt.e(resid_std)) # (B, T, F) + # Synthetic token ids: just whatever, the renderer doesn't care. + token_ids = torch.randint( + 0, 1000, (batch_size, SEQ_LEN), dtype=torch.long + ) + collector.update(token_ids, gates.float()) + + +def test_full_pipeline_contract(tmp_path: Path): + """Load -> collect -> dump produces a directory the renderer can consume.""" + src_molt = _build_molt() + ckpt_path = _save_lightning_checkpoint(tmp_path, src_molt) + + molt, meta = load_molt(ckpt_path, device="cpu") + assert meta.n_features == src_molt.n_features + + collector = GateCollector( + n_features=meta.n_features, top_k=TOP_K, seq_len=SEQ_LEN + ) + _populate_collector_with_molt_gates(molt, collector) + + tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2") + + out_dir = tmp_path / "dash" + dump_dashboard( + collector=collector, + meta=meta, + tokenizer=tokenizer, + out_dir=out_dir, + layer=LAYER, + dataset_name="synthetic", + window=2, + ) + + # --- File layout --- + for rel in [ + "data/metadata.json", + "index.html", + "dashboard.html", + "assets/dashboard.css", + "assets/dashboard.js", + "assets/index.js", + ]: + assert (out_dir / rel).is_file(), f"missing {rel}" + feat_files = sorted((out_dir / "data" / "features").glob("*.json")) + assert len(feat_files) == meta.n_features + + # --- metadata.json invariants the renderer relies on --- + md = json.loads((out_dir / "data" / "metadata.json").read_text()) + assert md["schema_version"] == 1 + assert md["n_features"] == meta.n_features + assert len(md["feature_activation_rate"]) == meta.n_features + assert len(md["feature_max_activation"]) == meta.n_features + assert len(md["feature_tier"]) == meta.n_features + assert len(md["feature_rank"]) == meta.n_features + assert all(0.0 <= r <= 1.0 for r in md["feature_activation_rate"]) + assert all(m >= 0.0 for m in md["feature_max_activation"]) + + # --- per-feature invariants --- + for fp in feat_files: + d = json.loads(fp.read_text()) + assert {"feature_id", "tier", "rank", "activation_rate", + "max_activation", "examples"} <= set(d.keys()) + assert 0.0 <= d["activation_rate"] <= 1.0 + for ex in d["examples"]: + assert {"peak_activation", "peak_token_pos", "tokens", "activations"} <= set(ex.keys()) + assert len(ex["tokens"]) == len(ex["activations"]) + assert 0 <= ex["peak_token_pos"] < len(ex["tokens"]) + # peak_token_pos must index the argmax of the windowed activations. + argmax = max( + range(len(ex["activations"])), + key=lambda i: ex["activations"][i], + ) + assert argmax == ex["peak_token_pos"] + # Peak activation must be > 0 (dead-feature filter in feature_summary). + assert ex["peak_activation"] > 0 + + +def test_cli_help_runs(): + """Make sure argparse construction (and therefore imports) don't blow up.""" + result = subprocess.run( + [sys.executable, "-m", "crosslayer_transcoder.feature_dash", "--help"], + capture_output=True, + text=True, + timeout=60, + ) + assert result.returncode == 0, result.stderr + assert "feature dashboard" in result.stdout.lower() + # Both source flags should appear. + assert "--hf-filename" in result.stdout + assert "--local-ckpt" in result.stdout + assert "--out" in result.stdout diff --git a/tests/test_feature_dash_load.py b/tests/test_feature_dash_load.py new file mode 100644 index 0000000..9ceb59a --- /dev/null +++ b/tests/test_feature_dash_load.py @@ -0,0 +1,124 @@ +"""Smoke test for the MoLT feature-dash checkpoint loader. + +Builds a tiny MoLT, wraps it in a MoltModule, saves a Lightning checkpoint via +trainer.save_checkpoint, then loads it back with `load_molt` and checks: + + - architecture is recovered (d_acts, n_features, n_layers, ranks, N), + - standardizer buffers round-trip, + - JumpReLU theta round-trips, + - forward output matches the source model bit-for-bit on a fixed input, + - the per-feature tier/rank index lines up with the rank tier sizes. +""" + +from __future__ import annotations + +import lightning as L +import pytest +import torch + +from crosslayer_transcoder.feature_dash.load import infer_molt_arch, load_molt +from crosslayer_transcoder.model.clt_lightning import MoltModule +from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt +from crosslayer_transcoder.model.standardize import ( + DimensionwiseInputStandardizer, + DimensionwiseOutputStandardizer, +) + + +D_ACTS = 16 +N_LAYERS = 4 +N = 3 +RANKS = [8, 4] # tier 0: 3 transforms, tier 1: 6 transforms -> 9 features +LAYER = 2 + + +def _build_molt() -> Molt: + n_features = N * 1 + N * 2 # = 9 + nonlin = JumpReLU(theta=0.05, bandwidth=1.0, n_layers=1, d_features=n_features) + in_std = DimensionwiseInputStandardizer(n_layers=N_LAYERS, activation_dim=D_ACTS) + out_std = DimensionwiseOutputStandardizer(n_layers=N_LAYERS, activation_dim=D_ACTS) + + fake_batch = torch.randn(8, 2, N_LAYERS, D_ACTS) + in_std.initialize_from_batch(fake_batch) + out_std.initialize_from_batch(fake_batch) + + return Molt( + d_acts=D_ACTS, + N=N, + ranks=RANKS, + nonlinearity=nonlin, + input_standardizer=in_std, + output_standardizer=out_std, + ) + + +def _save_lightning_checkpoint(tmp_path, molt: Molt) -> str: + module = MoltModule(model=molt) + trainer = L.Trainer( + accelerator="cpu", + devices=1, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + max_steps=0, + ) + # `strategy.connect` wires the module so save_checkpoint works without + # actually running `fit`. + trainer.strategy.connect(module) + ckpt_path = tmp_path / "molt.ckpt" + trainer.save_checkpoint(str(ckpt_path)) + return str(ckpt_path) + + +def test_infer_molt_arch_from_state_dict(): + molt = _build_molt() + sd = {f"model.{k}": v for k, v in molt.state_dict().items()} + # mimic what Lightning would add + sd["last_active"] = torch.zeros((molt.n_features,), dtype=torch.long) + + arch = infer_molt_arch(sd) + assert arch == { + "d_acts": D_ACTS, + "n_features": N * 1 + N * 2, + "n_layers": N_LAYERS, + "ranks": RANKS, + "N": N, + } + + +def test_load_molt_roundtrip(tmp_path): + torch.manual_seed(0) + src = _build_molt() + ckpt_path = _save_lightning_checkpoint(tmp_path, src) + + loaded, meta = load_molt(ckpt_path, device="cpu") + + assert meta.d_acts == D_ACTS + assert meta.n_features == src.n_features + assert meta.ranks == RANKS + assert meta.N == N + assert meta.n_layers == N_LAYERS + + # Tier index has the right tier counts. + assert meta.feature_tier.count(0) == N + assert meta.feature_tier.count(1) == 2 * N + assert meta.feature_rank[0] == RANKS[0] + assert meta.feature_rank[-1] == RANKS[1] + assert len(meta.feature_tier) == src.n_features + + # Forward output matches bit-for-bit on a fixed input. + src.eval() + loaded.eval() + x = torch.randn(5, D_ACTS) + with torch.no_grad(): + g_src, rn_src, r_src = src(x, layer=LAYER) + g_loaded, rn_loaded, r_loaded = loaded(x, layer=LAYER) + assert torch.allclose(g_src, g_loaded) + assert torch.allclose(rn_src, rn_loaded) + assert torch.allclose(r_src, r_loaded) + + # Loaded model must be in eval and have grads disabled. + assert not loaded.training + assert all(not p.requires_grad for p in loaded.parameters()) diff --git a/tests/test_molt_smoke.py b/tests/test_molt_smoke.py new file mode 100644 index 0000000..ebb28f6 --- /dev/null +++ b/tests/test_molt_smoke.py @@ -0,0 +1,110 @@ +"""Smoke tests for the MoLT (Mixture of Low-rank Transcoders) port. + +CPU test exercises the model wiring (low-rank transforms + JumpReLU + standardizers). +GPU test runs a fp32 and a 16-mixed forward+backward+optimizer step on synthetic +activations and checks that all parameters and the loss remain finite — mirrors +what the full-config CLI smoke covers, without the data generator or wandb. +""" + +import pytest +import torch + +from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt +from crosslayer_transcoder.model.standardize import ( + DimensionwiseInputStandardizer, + DimensionwiseOutputStandardizer, +) + + +D_ACTS = 64 +N_LAYERS = 12 +N = 4 +RANKS = [8, 4] +B = 4 + + +def _build_molt(device: torch.device) -> Molt: + n_features = N + 2 * N + nonlin = JumpReLU(theta=0.03, bandwidth=1.0, n_layers=1, d_features=n_features) + in_std = DimensionwiseInputStandardizer(n_layers=N_LAYERS, activation_dim=D_ACTS) + out_std = DimensionwiseOutputStandardizer(n_layers=N_LAYERS, activation_dim=D_ACTS) + + fake_batch = torch.randn(B, 2, N_LAYERS, D_ACTS) + in_std.initialize_from_batch(fake_batch) + out_std.initialize_from_batch(fake_batch) + + return Molt( + d_acts=D_ACTS, + N=N, + ranks=RANKS, + nonlinearity=nonlin, + input_standardizer=in_std, + output_standardizer=out_std, + ).to(device) + + +def test_molt_cpu_forward(): + torch.manual_seed(0) + m = _build_molt(torch.device("cpu")) + acts = torch.randn(B, D_ACTS) + + gate, recons_norm, recons = m(acts, layer=8) + + assert gate.shape == (B, m.n_features) + assert recons_norm.shape == (B, D_ACTS) + assert recons.shape == (B, D_ACTS) + assert torch.isfinite(gate).all() + assert torch.isfinite(recons).all() + + +def _run_train_step(device, autocast_dtype): + """Forward + backward + optimizer step. Mirrors MoltModule.training_step + without depending on the Lightning Trainer.""" + torch.manual_seed(0) + m = _build_molt(device) + optim = torch.optim.Adam(m.parameters(), lr=2e-4) + scaler = torch.amp.GradScaler("cuda", enabled=autocast_dtype is torch.float16) + + resid = torch.randn(B, D_ACTS, device=device) + mlp_out = torch.randn(B, D_ACTS, device=device) + + optim.zero_grad(set_to_none=True) + if autocast_dtype is None: + gate, recons_norm, _ = m(resid, layer=8) + target = m.output_standardizer.standardize(mlp_out, 8) + mse = ((recons_norm - target) ** 2).mean() + norms = m.transform_norm() + sparsity = torch.tanh(norms * gate * 100.0).sum(dim=-1).mean() * 1.5e-4 + loss = mse + sparsity + loss.backward() + optim.step() + else: + with torch.amp.autocast("cuda", dtype=autocast_dtype): + gate, recons_norm, _ = m(resid, layer=8) + target = m.output_standardizer.standardize(mlp_out, 8) + mse = ((recons_norm - target) ** 2).mean() + norms = m.transform_norm() + sparsity = torch.tanh(norms * gate * 100.0).sum(dim=-1).mean() * 1.5e-4 + loss = mse + sparsity + scaler.scale(loss).backward() + scaler.step(optim) + scaler.update() + + return loss, m + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_molt_gpu_fp32_train_step(): + loss, m = _run_train_step(torch.device("cuda"), autocast_dtype=None) + assert torch.isfinite(loss), f"non-finite loss: {loss.item()}" + for name, p in m.named_parameters(): + assert torch.isfinite(p).all(), f"non-finite param after step: {name}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_molt_gpu_amp_train_step(): + loss, m = _run_train_step(torch.device("cuda"), autocast_dtype=torch.float16) + assert torch.isfinite(loss), f"non-finite loss: {loss.item()}" + for name, p in m.named_parameters(): + assert torch.isfinite(p).all(), f"non-finite param after AMP step: {name}"