diff --git a/config/molt-5090.yaml b/config/molt-5090.yaml new file mode 100644 index 0000000..e29ade4 --- /dev/null +++ b/config/molt-5090.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with the following classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. + +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_0005-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. 
+ replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0005 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. 
+ buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_00001.yaml b/config/molt-5090_20M_tokens_0_00001.yaml new file mode 100644 index 0000000..118e282 --- /dev/null +++ b/config/molt-5090_20M_tokens_0_00001.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with the following classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. 
+ +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_00001-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_00001" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. 
+ replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00001 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. 
+ buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_0001.yaml b/config/molt-5090_20M_tokens_0_0001.yaml new file mode 100644 index 0000000..e267c9e --- /dev/null +++ b/config/molt-5090_20M_tokens_0_0001.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with the following classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. 
+ +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_0001-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_0001" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. 
+ replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0001 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. 
+ buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_00015.yaml b/config/molt-5090_20M_tokens_0_00015.yaml new file mode 100644 index 0000000..0700010 --- /dev/null +++ b/config/molt-5090_20M_tokens_0_00015.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with the following classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. 
+ +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_00015-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_00015" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. 
+ replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00015 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. 
+ buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_20M_tokens_0_0005.yaml b/config/molt-5090_20M_tokens_0_0005.yaml new file mode 100644 index 0000000..d8f9330 --- /dev/null +++ b/config/molt-5090_20M_tokens_0_0005.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with the following classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. 
+ +seed_everything: 42 + +trainer: + max_steps: 20_000 # 20M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_0005-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_0005" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. 
+ replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0005 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. 
+ buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_40M_tokens_0_00015_N=100.yaml b/config/molt-5090_40M_tokens_0_00015_N=100.yaml new file mode 100644 index 0000000..2d94a92 --- /dev/null +++ b/config/molt-5090_40M_tokens_0_00015_N=100.yaml @@ -0,0 +1,151 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Doubles N (50 -> 100) and max_steps (20k -> 40k) vs. molt-5090_20M_tokens_0_00015.yaml. +# +# Why: per the MOLT bulk-update post (transformer-circuits.pub/2025/bulk-update), +# training steps should scale proportionally to the number of features, and each 4x +# FLOPs increase comes from 2x params + 2x steps. Doubling N exactly doubles features +# (1550 -> 3100), latents (128k -> 256k), and params (~198M -> ~396M); doubling +# max_steps gives 40M tokens at batch_size=1000 -> 4x FLOPs vs. the 20M baseline. 
+ +seed_everything: 42 + +trainer: + max_steps: 40_000 # 40M tokens at batch_size=1000 + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 2 # mini-batch=500, effective batch=1000 (matches N=50 baseline) + # Memory tuning to fit N=100 (~396M params) on the 32 GB 5090 alongside the + # GPU-resident GPT-2 data generator (~8 GiB): + # - bf16-mixed: autocast roughly halves activation memory in the training pass (weights and optimizer state stay fp32). + # - mini-batch=500 + accumulate=2: peak forward activation `raw_recons` shrinks + # from [1000,3100,768] (9.5 GB fp32 / 4.76 GB bf16) to [500,3100,768] + # (~2.4 GB bf16). Effective batch (1000), total tokens (40M), optimizer-step + # count (max_steps=40_000), and the linear sparsity ramp (keyed on + # global_step/max_steps) are all preserved vs. the original plan. + # Note: bf16 is a numerical-regime change vs. the float32 N=50 baseline; the + # 4× FLOPs comparison is approximate at the bit level but exact at the math level. 
+ precision: "bf16-mixed" + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_00015-3100-256k-jumprelu-bf16-40M-N100" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_00015_N100_40M" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 100 # doubled from 50 -> features 1550 -> 3100, latents 128k -> 256k, params ~198M -> ~396M + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 3100 # = 31 * N (sum of N*mult across the 5 rank levels) + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. + replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00015 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. 
+ # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 500 # halved from 1000; trainer.accumulate_grad_batches=2 restores effective batch=1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-5090_50M_tokens_0_00015.yaml b/config/molt-5090_50M_tokens_0_00015.yaml new file mode 100644 index 0000000..7c8fb8d --- /dev/null +++ b/config/molt-5090_50M_tokens_0_00015.yaml @@ -0,0 +1,137 @@ +# Single-GPU configuration for RTX 5090 (32 GB VRAM, 31 GB /dev/shm). +# Derived from molt.yaml with the following classes of changes: +# 1. All components moved to cuda:0 (only one GPU available). +# 2. buffer_size shrunk so the activation ring buffer fits in /dev/shm. 
+ +seed_everything: 42 + +trainer: + max_steps: 50_000 # 50M tokens + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + accelerator: "gpu" + devices: [0] # Only one GPU on this machine + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-5090-0_00015-1550-128k-jumprelu-float32-50M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints/lam_0_00015_50M" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Validation is gated to never fire (val_check_interval=1e9), so this metric + # would never run anyway. 
+ replacement_model: null + + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + learning_rate: 2e-4 + compile: true + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.00015 + c_sparsity: 100 + use_tanh: true + + # NOTE: For MoltModule these settings are inert. MOLT's dead-feature path + # is hardcoded — it logs `metrics/dead_features` every step via + # update_dead_features(gate) regardless of these flags. The DeadFeatures + # torchmetrics instance configured above is also unused for MOLT. + # TODO: wire MoltModule.training_step into the configurable path. + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + # 350k * 12 layers * 2 (in/out) * 768 dim * 4 bytes ≈ 25.8 GB — fits in /dev/shm (31 GB cap). + # Original config used 1_000_000 which would need ~74 GB. 
+ buffer_size: 350_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + generation_batch_size: 10 + refresh_interval: 0.1 + + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 + + use_shared_memory: true + + device_map: "cuda:0" # was cuda:0 already — unchanged + deployment_policy: "gpu_only" + + wandb_logging: + enabled: true + project: "debug-molt" + group: null + run_name: "data-generator" + tags: ["data-generation"] + save_dir: "./wandb" + log_interval: 5.0 + +ckpt_path: null diff --git a/config/molt-long.yaml b/config/molt-long.yaml new file mode 100644 index 0000000..ebfd470 --- /dev/null +++ b/config/molt-long.yaml @@ -0,0 +1,150 @@ +# Default configuration for CrossLayer Transcoder training +# This file uses Lightning CLI's automatic class construction + +seed_everything: 42 + +trainer: + # max_steps is number of gradient updates, not number of batches + max_steps: 200_000 # 200M tokens at batch_size=1000 + #limit_train_batches: 10_000 + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint + num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + precision: "16-mixed" + accelerator: "gpu" + devices: [1] # 4; [1] for cuda:1 + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-0_01" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + 
init_args: + checkpoint_dir: "checkpoints" + #- class_path: crosslayer_transcoder.utils.callbacks.TensorBoardProfilerCallback + # init_args: + # log_dir: "log/profiler" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + # Pre-constructed Molt model + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + # Model architecture parameters + d_acts: 768 + n_transforms: 3000 + d_transform: 32 + nonlinearity: + class_path: torch.nn.ReLU + + # Pre-constructed standardizers + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Pre-constructed replacement model + replacement_model: + class_path: crosslayer_transcoder.metrics.replacement_model_accuracy.ReplacementModelAccuracy + init_args: + model_name: "openai-community/gpt2" + device_map: "cuda:1" + loader_batch_size: 2 + + # Pre-constructed dead features metric + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + + # Training parameters + learning_rate: 1e-4 + compile: false + lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.01 + c_sparsity: 1 + use_tanh: true + + # Dead features computation settings + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + buffer_size: 1_000_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float16" + max_batch_size: 50000 + + # Model settings for activation generation + model_name: "openai-community/gpt2" + model_dtype:
"float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + # Generation settings + generation_batch_size: 10 + refresh_interval: 0.1 + + # Memory settings + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + # File paths + init_file: "/var/local/glang/activations/clt-activations-10M-shuffled_fp16.h5" + + # DataLoader settings + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 # Only provide activations when buffer is at least 1% full + + use_shared_memory: true + + # Device configuration + device_map: "cuda:0" # "cpu", "auto", "cuda:0", "cuda:0,1,2,3" + deployment_policy: "gpu_only" # "cpu_only", "gpu_only", or "dynamic" + + # WandB logging configuration for data generation + wandb_logging: + enabled: true # Enable WandB logging for data generation + project: "debug-molt" # WandB project (should match trainer logger) + group: null # Group name (null = auto-generated from training run) + run_name: "data-generator" # Run name suffix + tags: ["data-generation"] # Tags for the data generation run + save_dir: "./wandb" # Directory for WandB files + log_interval: 5.0 # Logging interval in seconds + +ckpt_path: null \ No newline at end of file diff --git a/config/molt.yaml b/config/molt.yaml new file mode 100644 index 0000000..9ce945a --- /dev/null +++ b/config/molt.yaml @@ -0,0 +1,154 @@ +# Default configuration for CrossLayer Transcoder training +# This file uses Lightning CLI's automatic class construction + +seed_everything: 42 + +trainer: + # max_steps is number of gradient updates, not number of batches + max_steps: 20_000 # 20M tokens + #limit_train_batches: 10_000 + val_check_interval: 1000000000 # replacement model for single layer doesn't make sense + limit_val_batches: 1 + check_val_every_n_epoch: null + enable_checkpointing: false # We use custom end-of-training checkpoint
+ num_sanity_val_steps: 0 # Can't run replacement model before standardizers are initialized + # precision: "16-mixed" + accelerator: "gpu" + devices: [2] # 4; [1] for cuda:1 + accumulate_grad_batches: 1 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: "debug-molt" + name: "molt-0_0005-1550-128k-jumprelu-float32-20M" + save_dir: "./wandb" + callbacks: + - class_path: crosslayer_transcoder.utils.callbacks.EndOfTrainingCheckpointCallback + init_args: + checkpoint_dir: "checkpoints" + #- class_path: crosslayer_transcoder.utils.callbacks.TensorBoardProfilerCallback + # init_args: + # log_dir: "log/profiler" + +model: + class_path: crosslayer_transcoder.model.clt_lightning.MoltModule + init_args: + # Pre-constructed Molt model + model: + class_path: crosslayer_transcoder.model.molt.Molt + init_args: + # Model architecture parameters + d_acts: 768 + N: 50 + nonlinearity: + class_path: crosslayer_transcoder.model.jumprelu.JumpReLU + init_args: + theta: 0.03 + bandwidth: 1.0 + n_layers: 1 + d_features: 1550 + + # Pre-constructed standardizers + input_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + output_standardizer: + class_path: crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer + init_args: + n_layers: 12 + activation_dim: 768 + + # Pre-constructed replacement model + replacement_model: + class_path: crosslayer_transcoder.metrics.replacement_model_accuracy.ReplacementModelAccuracy + init_args: + model_name: "openai-community/gpt2" + device_map: "cuda:1" + loader_batch_size: 2 + + # Pre-constructed dead features metric + dead_features: + class_path: crosslayer_transcoder.metrics.dead_features.DeadFeatures + init_args: + n_features: 120 + n_layers: 12 + return_per_layer: true + return_log_freqs: true + return_neuron_indices: true + + + # Training parameters + learning_rate: 2e-4 + compile: true
+ lr_decay_step: 1000000000 + lr_decay_factor: 0.1 + + lambda_sparsity: 0.0005 + c_sparsity: 100 + use_tanh: true + + # Dead features computation settings + compute_dead_features: true + compute_dead_features_every: 500 + +data: + class_path: crosslayer_transcoder.data.datamodule.ActivationDataModule + init_args: + # Buffer settings + buffer_size: 1_000_000 + n_in_out: 2 + n_layers: 12 + activation_dim: 768 + dtype: "float32" + max_batch_size: 50000 + + # Model settings for activation generation + model_name: "openai-community/gpt2" + model_dtype: "float32" + + # Dataset settings + dataset_name: "Skylion007/openwebtext" + dataset_split: "train" + max_sequence_length: 1024 + + # Generation settings + generation_batch_size: 10 + refresh_interval: 0.1 + + # Memory settings + shared_memory_name: "activation_buffer" + timeout_seconds: 30 + + # File paths + # init_file: "/var/local/glang/activations/clt-activations-10M-shuffled_fp16.h5" + + # DataLoader settings + batch_size: 1000 + num_workers: 10 + prefetch_factor: 2 + shuffle: true + persistent_workers: true + pin_memory: true + + minimum_fill_threshold: 0.01 # Only provide activations when buffer is at least 1% full + + use_shared_memory: true + + # Device configuration + device_map: "cuda:0" # "cpu", "auto", "cuda:0", "cuda:0,1,2,3" + deployment_policy: "gpu_only" # "cpu_only", "gpu_only", or "dynamic" + + # WandB logging configuration for data generation + wandb_logging: + enabled: true # Enable WandB logging for data generation + project: "debug-molt" # WandB project (should match trainer logger) + group: null # Group name (null = auto-generated from training run) + run_name: "data-generator" # Run name suffix + tags: ["data-generation"] # Tags for the data generation run + save_dir: "./wandb" # Directory for WandB files + log_interval: 5.0 # Logging interval in seconds + +ckpt_path: null \ No newline at end of file diff --git a/crosslayer_transcoder/data/datamodule.py b/crosslayer_transcoder/data/datamodule.py
index 5df6778..20b91ad 100644 --- a/crosslayer_transcoder/data/datamodule.py +++ b/crosslayer_transcoder/data/datamodule.py @@ -336,7 +336,7 @@ def teardown(self, stage: str = None): """Clean up resources.""" logger.info("Cleaning up activation data loader...") - if self.data_loader and hasattr(self.data_loader, "cleanup"): + if self.data_loader is not None and hasattr(self.data_loader, "cleanup"): self.data_loader.cleanup() if self.data_generator and self.data_generator.is_alive(): diff --git a/crosslayer_transcoder/model/__init__.py b/crosslayer_transcoder/model/__init__.py index e4f0aae..6a88618 100644 --- a/crosslayer_transcoder/model/__init__.py +++ b/crosslayer_transcoder/model/__init__.py @@ -4,11 +4,13 @@ from .clt import CrossLayerTranscoder from .clt_lightning import CrossLayerTranscoderModule +from .molt import Molt from .topk import BatchTopK, PerLayerBatchTopK, PerLayerTopK __all__ = [ "CrossLayerTranscoder", "CrossLayerTranscoderModule", + "Molt", "BatchTopK", "PerLayerTopK", "PerLayerBatchTopK", diff --git a/crosslayer_transcoder/model/clt_lightning.py b/crosslayer_transcoder/model/clt_lightning.py index 8e4a020..2256028 100644 --- a/crosslayer_transcoder/model/clt_lightning.py +++ b/crosslayer_transcoder/model/clt_lightning.py @@ -2,7 +2,7 @@ import os import subprocess import time -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import lightning as L import psutil @@ -23,6 +23,7 @@ Decoder, ) from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt from crosslayer_transcoder.model.topk import BatchTopK @@ -30,7 +31,7 @@ class CrossLayerTranscoderModule(L.LightningModule): def __init__( self, # Pre-constructed modules - model: CrossLayerTranscoder, + model: Union[CrossLayerTranscoder, Molt], replacement_model: Optional[ReplacementModelAccuracy] = None, dead_features: Optional[DeadFeatures] = None, # Training parameters @@ -85,17 +86,23 @@ def __init__( self.beta2 = 
beta2 self.log_metrics_every = log_metrics_every - assert self.model.encoder.n_layers == self.model.decoder.n_layers, ( - "Encoder and decoder must have the same number of layers" - ) + if isinstance(self.model, Molt): + self.register_buffer( + "last_active", + torch.zeros((self.model.n_features,), dtype=torch.long), + ) + else: + assert self.model.encoder.n_layers == self.model.decoder.n_layers, ( + "Encoder and decoder must have the same number of layers" + ) - self.register_buffer( - "last_active", - torch.zeros( - (self.model.encoder.n_layers, self.model.encoder.d_features), - dtype=torch.long, - ), - ) + self.register_buffer( + "last_active", + torch.zeros( + (self.model.encoder.n_layers, self.model.encoder.d_features), + dtype=torch.long, + ), + ) def configure_model(self): # Apply compilation if requested @@ -565,3 +572,68 @@ def training_step(self, batch, batch_idx): torch.cuda.memory._record_memory_history(enabled=None) exit() return loss + + +class MoltModule(CrossLayerTranscoderModule): + def __init__( + self, + lambda_sparsity: float = 0.0002, + c_sparsity: float = 0.1, + use_tanh: bool = True, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._lambda = lambda_sparsity + self.c = c_sparsity + self.use_tanh = use_tanh + + def current_sparsity_penalty(self): + n_steps = self.trainer.max_steps + current_step = ( + self.global_step + ) # use global step instead of batch idx to work with gradient accumulation + cur_lambda = self._lambda * (current_step / n_steps) + self.log("training/sparsity_penalty", cur_lambda) + return cur_lambda + + def forward(self, batch, layer): + return self.model.forward(batch, layer) + + def training_step(self, batch, batch_idx): + if batch_idx == 0: + self.model.initialize_standardizers(batch) + self.log("model/d_latents", self.model.d_latents) + self.log("model/n_features", self.model.n_features) + + layer = 8 + + # Forward pass + resid, mlp_out = batch[:, 0], batch[:, 1] + resid = resid[:, layer] + mlp_out = 
mlp_out[:, layer] + gate, recons_norm, recons = self.model.forward(resid, layer) + + self.update_dead_features(gate) + # Compute MSE loss + mse = (recons_norm - self.model.output_standardizer.standardize(mlp_out, layer)) ** 2 + + # Compute Sparsity Loss + norms = self.model.transform_norm() + weighted_norms = norms * gate + self.log("model/weighted_norms_mean", weighted_norms.detach().mean().cpu()) + + if self.use_tanh: + weighted_norms = torch.tanh(weighted_norms * self.c) + sparsity = self.current_sparsity_penalty() * weighted_norms.sum(dim=-1).mean() + self.log("training/sparsity_loss", sparsity) + self.log("L0", (gate > 0.0).float().sum() / gate.shape[0]) + + loss = mse.mean() + sparsity + self.log("training/mse", mse.mean()) + self.log("training/loss", loss) + + if batch_idx % self.log_metrics_every == 0: + pass + + return loss diff --git a/crosslayer_transcoder/model/jumprelu.py b/crosslayer_transcoder/model/jumprelu.py index d21b7ec..c977830 100644 --- a/crosslayer_transcoder/model/jumprelu.py +++ b/crosslayer_transcoder/model/jumprelu.py @@ -52,7 +52,8 @@ def backward(ctx, grad_output): class JumpReLU(SerializableModule): def __init__(self, theta=0.0, bandwidth=1.0, n_layers=12, d_features=768 * 8): super().__init__() - self.theta = nn.Parameter(torch.full((1, n_layers, d_features), theta)) + shape = (1, n_layers, d_features) if n_layers > 1 else (1, d_features) + self.theta = nn.Parameter(torch.full(shape, theta)) self.register_buffer("bandwidth", torch.tensor(bandwidth)) self._init_theta = theta self.n_layers = n_layers diff --git a/crosslayer_transcoder/model/molt.py b/crosslayer_transcoder/model/molt.py new file mode 100644 index 0000000..afb3fd3 --- /dev/null +++ b/crosslayer_transcoder/model/molt.py @@ -0,0 +1,93 @@ +import einops +import torch +import torch.nn as nn +from jaxtyping import Float + + +class Molt(nn.Module): + def __init__( + self, + d_acts: int, + N: int, + nonlinearity: nn.Module, + input_standardizer: nn.Module, + 
output_standardizer: nn.Module, + ranks: list[int] = [512, 256, 128, 64, 32], + ): + super().__init__() + + self.d_acts = d_acts + self.nonlinearity = nonlinearity + self.input_standardizer = input_standardizer + self.output_standardizer = output_standardizer + Us = [] + Vs = [] + rank_multiplier = 1 + n_features = 0 + d_latents = 0 + for rank in ranks: + Us.append(nn.Parameter(torch.empty(N * rank_multiplier, rank, d_acts))) + Vs.append(nn.Parameter(torch.empty(N * rank_multiplier, d_acts, rank))) + n_features += N * rank_multiplier + d_latents += N * rank_multiplier * rank + rank_multiplier *= 2 + self.n_features = n_features + self.e = nn.Linear(d_acts, n_features) + self.Us = nn.ParameterList(Us) + self.Vs = nn.ParameterList(Vs) + + print(f"d_latents (transcoder equivalent): {d_latents}") + self.d_latents = d_latents + + self.reset_parameters() + + def reset_parameters(self): + for U in self.Us: + nn.init.xavier_uniform_(U) + for V in self.Vs: + nn.init.xavier_uniform_(V) + + def transform_norm(self): + norms = [] + for U, V in zip(self.Us, self.Vs): + uv = einops.einsum( + U, + V, + "n_transforms d_transform d_acts_out, n_transforms d_acts_in d_transform -> n_transforms d_acts_in d_acts_out", + ) + norms.append(torch.norm(uv, dim=(1, 2))) + return torch.cat(norms, dim=0) + + def forward( + self, acts: Float[torch.Tensor, "batch_size d_acts"], layer: int + ) -> Float[torch.Tensor, "batch_size d_acts"]: + acts = self.input_standardizer(acts, layer) + pre_actvs = self.e(acts) + gate = self.nonlinearity(pre_actvs) # (batch, n_transforms) + + raw_recons = [] + for U, V in zip(self.Us, self.Vs): + latents = einops.einsum( + acts, + V, + "batch d_acts, n_transforms d_acts d_transform -> batch n_transforms d_transform", + ) + raw_recons.append( + einops.einsum( + latents, + U, + "batch n_transforms d_transform, n_transforms d_transform d_acts -> batch n_transforms d_acts", + ) + ) + + raw_recons = torch.cat(raw_recons, dim=1) + + weighted_recons = gate.unsqueeze(-1) * 
raw_recons + recons_norm = weighted_recons.sum(dim=1) + + recons = self.output_standardizer(recons_norm, layer) + return gate, recons_norm, recons + + def initialize_standardizers(self, batch: Float[torch.Tensor, "batch_size io n_layers d_acts"]): + self.input_standardizer.initialize_from_batch(batch) + self.output_standardizer.initialize_from_batch(batch) diff --git a/tests/test_molt_smoke.py b/tests/test_molt_smoke.py new file mode 100644 index 0000000..ebb28f6 --- /dev/null +++ b/tests/test_molt_smoke.py @@ -0,0 +1,110 @@ +"""Smoke tests for the MoLT (Mixture of Low-rank Transcoders) port. + +CPU test exercises the model wiring (low-rank transforms + JumpReLU + standardizers). +GPU test runs a fp32 and a 16-mixed forward+backward+optimizer step on synthetic +activations and checks that all parameters and the loss remain finite — mirrors +what the full-config CLI smoke covers, without the data generator or wandb. +""" + +import pytest +import torch + +from crosslayer_transcoder.model.jumprelu import JumpReLU +from crosslayer_transcoder.model.molt import Molt +from crosslayer_transcoder.model.standardize import ( + DimensionwiseInputStandardizer, + DimensionwiseOutputStandardizer, +) + + +D_ACTS = 64 +N_LAYERS = 12 +N = 4 +RANKS = [8, 4] +B = 4 + + +def _build_molt(device: torch.device) -> Molt: + n_features = N + 2 * N + nonlin = JumpReLU(theta=0.03, bandwidth=1.0, n_layers=1, d_features=n_features) + in_std = DimensionwiseInputStandardizer(n_layers=N_LAYERS, activation_dim=D_ACTS) + out_std = DimensionwiseOutputStandardizer(n_layers=N_LAYERS, activation_dim=D_ACTS) + + fake_batch = torch.randn(B, 2, N_LAYERS, D_ACTS) + in_std.initialize_from_batch(fake_batch) + out_std.initialize_from_batch(fake_batch) + + return Molt( + d_acts=D_ACTS, + N=N, + ranks=RANKS, + nonlinearity=nonlin, + input_standardizer=in_std, + output_standardizer=out_std, + ).to(device) + + +def test_molt_cpu_forward(): + torch.manual_seed(0) + m = _build_molt(torch.device("cpu")) + acts = 
torch.randn(B, D_ACTS) + + gate, recons_norm, recons = m(acts, layer=8) + + assert gate.shape == (B, m.n_features) + assert recons_norm.shape == (B, D_ACTS) + assert recons.shape == (B, D_ACTS) + assert torch.isfinite(gate).all() + assert torch.isfinite(recons).all() + + +def _run_train_step(device, autocast_dtype): + """Forward + backward + optimizer step. Mirrors MoltModule.training_step + without depending on the Lightning Trainer.""" + torch.manual_seed(0) + m = _build_molt(device) + optim = torch.optim.Adam(m.parameters(), lr=2e-4) + scaler = torch.amp.GradScaler("cuda", enabled=autocast_dtype is torch.float16) + + resid = torch.randn(B, D_ACTS, device=device) + mlp_out = torch.randn(B, D_ACTS, device=device) + + optim.zero_grad(set_to_none=True) + if autocast_dtype is None: + gate, recons_norm, _ = m(resid, layer=8) + target = m.output_standardizer.standardize(mlp_out, 8) + mse = ((recons_norm - target) ** 2).mean() + norms = m.transform_norm() + sparsity = torch.tanh(norms * gate * 100.0).sum(dim=-1).mean() * 1.5e-4 + loss = mse + sparsity + loss.backward() + optim.step() + else: + with torch.amp.autocast("cuda", dtype=autocast_dtype): + gate, recons_norm, _ = m(resid, layer=8) + target = m.output_standardizer.standardize(mlp_out, 8) + mse = ((recons_norm - target) ** 2).mean() + norms = m.transform_norm() + sparsity = torch.tanh(norms * gate * 100.0).sum(dim=-1).mean() * 1.5e-4 + loss = mse + sparsity + scaler.scale(loss).backward() + scaler.step(optim) + scaler.update() + + return loss, m + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_molt_gpu_fp32_train_step(): + loss, m = _run_train_step(torch.device("cuda"), autocast_dtype=None) + assert torch.isfinite(loss), f"non-finite loss: {loss.item()}" + for name, p in m.named_parameters(): + assert torch.isfinite(p).all(), f"non-finite param after step: {name}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def 
test_molt_gpu_amp_train_step(): + loss, m = _run_train_step(torch.device("cuda"), autocast_dtype=torch.float16) + assert torch.isfinite(loss), f"non-finite loss: {loss.item()}" + for name, p in m.named_parameters(): + assert torch.isfinite(p).all(), f"non-finite param after AMP step: {name}"