From 85d8026ef86d9cccdb058b87f6d5bd7855aa6c2f Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Wed, 24 Sep 2025 10:55:43 -0500 Subject: [PATCH 1/4] Add configs for seqlen 128 for mla and mla-o --- .../configs/seqlen128_mla-on_q96_k64_o96.json | 104 ++++++++++++++++++ .../configs/seqlen128_mla_q96_k64.json | 104 ++++++++++++++++++ 2 files changed, 208 insertions(+) create mode 100644 subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json create mode 100644 subspace_decoder/configs/seqlen128_mla_q96_k64.json diff --git a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json new file mode 100644 index 0000000..b53cfb8 --- /dev/null +++ b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json @@ -0,0 +1,104 @@ +{ + "shorthand": "seqlen.128 - mla-on.96.64.96 - mlp.1024 - model.256.lyr.6 - ah.8.32", + "notes": "MLA-o with sequential decomp and norm. Sequence length is 128.", + "model": { + "hidden_size": 256, + "num_hidden_layers": 6, + "num_nextn_predict_layers": 1, + "moe_intermediate_size": 128, + "intermediate_size": 1024, + "n_shared_experts": 1, + "n_routed_experts": 4, + "ep_size": 1, + "routed_scaling_factor": 1, + "num_experts_per_tok": 2, + "moe_layer_freq": 2, + "first_k_dense_replace": 1, + "topk_method": "softmax_aux", + "n_group": 1, + "topk_group": 1, + "norm_topk_prob": true, + "scoring_func": "softmax", + "hidden_act": "silu", + "use_cache": false, + "pad_token_id": 50256, + "bos_token_id": 50256, + "eos_token_id": 50256, + "tie_word_embeddings": true, + "attention_dropout": 0.0, + "hidden_dropout_prob": 0.1, + "classifier_dropout": null, + "initializer_range": 0.02, + "rms_norm_eps": 1e-06, + "vocab_size": 50257, + "rope_theta": 10000.0, + "rope_scaling": null, + "max_position_embeddings": 1024, + "kv_lora_rank": 64, + "q_lora_rank": 96, + "qk_rope_head_dim": 32, + "v_head_dim": 32, + "qk_nope_head_dim": 0, + "num_attention_heads": 8, + "num_key_value_heads": 8, + "attention_bias": false, + "use_output_subspace": true, + "o_proj_variant": "sequential_norm", + "o_latent_dim": 96, + "attention_backend": "flash_attention_2" + }, + "pre_train": { + "output_dir": "checkpoints/seqlen128_mla-on_q96_k64_o96", + "seed": 42, + "train_batch_size": 128, + "gradient_accumulation_steps": 8, + "learning_rate": 0.0005, + "num_train_steps": 12500, + "eval_steps": 1000, + "weight_decay": 0.01, + "num_workers": 8, + "pin_memory": true, + "dataset_name": "wikitext", + "dataset_config": "wikitext-103-raw-v1", + "max_seq_length": 128, + "eval_batch_size": 32, + "bf16": true, + "fp16": false, + "torch_compile": true, + "torch_compile_backend": "inductor", + "torch_compile_mode": "default" + }, + "fine_tune": { + "task": "sst2", + "tokenizer_name_or_path": "gpt2", + "method": "lm_label_words", + "label_words": { + "0": " negative", + "1": " positive" + }, + "train_batch_size": 256, + "gradient_accumulation_steps": 1, + "eval_batch_size": 256, + "learning_rate": 5e-05, + "weight_decay": 0.05, + "max_steps": 1500, + "warmup_ratio": 0.1, + "eval_steps": 150, + "logging_steps": 20, + "save_strategy": "no", + "save_total_limit": 0, + "report_to_wandb": true, + "bf16": true, + "fp16": false, + "torch_compile": false, + "torch_compile_backend": null, + "torch_compile_mode": null, + "lora": { + "enabled": false + }, + "seed": 42, + "max_seq_length": 128, + "output_dir": "checkpoints/seqlen128_mla-on_q96_k64_o96/ft_sst2", + "run_name": "ft-sst2 - seqlen.128 - mla-on.96.64.96 - mlp.1024 - model.256.lyr.6 - ah.8.32" + } +} diff --git a/subspace_decoder/configs/seqlen128_mla_q96_k64.json b/subspace_decoder/configs/seqlen128_mla_q96_k64.json new file mode 100644 index 0000000..20ad530 --- /dev/null +++ b/subspace_decoder/configs/seqlen128_mla_q96_k64.json @@ -0,0 +1,104 @@ +{ + "shorthand": "seqlen.128 - mla.96.64 - mlp.1024 - model.256.lyr.6 - ah.8.32", + "notes": "Baseline MLA, Sequence length is 128.", + "model": { + "hidden_size": 256, + "num_hidden_layers": 6, + "num_nextn_predict_layers": 1, + "moe_intermediate_size": 128, + "intermediate_size": 1024, + "n_shared_experts": 1, + "n_routed_experts": 4, + "ep_size": 1, + "routed_scaling_factor": 1, + "num_experts_per_tok": 2, + "moe_layer_freq": 2, + "first_k_dense_replace": 1, + "topk_method": "softmax_aux", + "n_group": 1, + "topk_group": 1, + "norm_topk_prob": true, + "scoring_func": "softmax", + "hidden_act": "silu", + "use_cache": false, + "pad_token_id": 50256, + "bos_token_id": 50256, + "eos_token_id": 50256, + "tie_word_embeddings": true, + "attention_dropout": 0.0, + "hidden_dropout_prob": 0.1, + "classifier_dropout": null, + "initializer_range": 0.02, + "rms_norm_eps": 1e-06, + "vocab_size": 50257, + "rope_theta": 10000.0, + "rope_scaling": null, + "max_position_embeddings": 1024, + "kv_lora_rank": 64, + "q_lora_rank": 96, + "qk_rope_head_dim": 32, + "v_head_dim": 32, + "qk_nope_head_dim": 0, + "num_attention_heads": 8, + "num_key_value_heads": 8, + "attention_bias": false, + "use_output_subspace": false, + "o_proj_variant": "vanilla", + "o_latent_dim": null, + "attention_backend": "flash_attention_2" + }, + "pre_train": { + "output_dir": "checkpoints/seqlen128_mla_q96_k64", + "seed": 42, + "train_batch_size": 128, + "gradient_accumulation_steps": 8, + "learning_rate": 0.0005, + "num_train_steps": 12500, + "eval_steps": 1000, + "weight_decay": 0.01, + "num_workers": 8, + "pin_memory": true, + "dataset_name": "wikitext", + "dataset_config": "wikitext-103-raw-v1", + "max_seq_length": 128, + "eval_batch_size": 32, + "bf16": true, + "fp16": false, + "torch_compile": true, + "torch_compile_backend": "inductor", + "torch_compile_mode": "default" + }, + "fine_tune": { + "task": "sst2", + "tokenizer_name_or_path": "gpt2", + "method": "lm_label_words", + "label_words": { + "0": " negative", + "1": " positive" + }, + "train_batch_size": 256, + "gradient_accumulation_steps": 1, + "eval_batch_size": 256, + "learning_rate": 5e-05, + "weight_decay": 0.05, + "max_steps": 1500, + "warmup_ratio": 0.1, + "eval_steps": 150, + "logging_steps": 20, + "save_strategy": "no", + "save_total_limit": 0, + "report_to_wandb": true, + "bf16": true, + "fp16": false, + "torch_compile": false, + "torch_compile_backend": null, + "torch_compile_mode": null, + "lora": { + "enabled": false + }, + "seed": 42, + "max_seq_length": 128, + "output_dir": "checkpoints/seqlen128_mla_q96_k64/ft_sst2", + "run_name": "ft-sst2 - seqlen.128 - mla.96.64 - mlp.1024 - model.256.lyr.6 - ah.8.32" + } +} From 6e5aaa9d92c5c63dbd6e0ee6a3aa394fc85fb074 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Wed, 24 Sep 2025 11:20:47 -0500 Subject: [PATCH 2/4] Change config to make all layers dense --- .../configs/seqlen128_mla-on_q96_k64_o96.json | 12 +++++++----- subspace_decoder/configs/seqlen128_mla_q96_k64.json | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json index b53cfb8..9afe46b 100644 --- a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json +++ b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json @@ -7,13 +7,15 @@ "num_nextn_predict_layers": 1, "moe_intermediate_size": 128, "intermediate_size": 1024, - "n_shared_experts": 1, - "n_routed_experts": 4, + + "n_shared_experts": 0, + "n_routed_experts": 0, + "num_experts_per_tok": 0, + "moe_layer_freq": 0, + "first_k_dense_replace": 0, + "ep_size": 1, "routed_scaling_factor": 1, - "num_experts_per_tok": 2, - "moe_layer_freq": 2, - "first_k_dense_replace": 1, "topk_method": "softmax_aux", "n_group": 1, "topk_group": 1, diff --git a/subspace_decoder/configs/seqlen128_mla_q96_k64.json b/subspace_decoder/configs/seqlen128_mla_q96_k64.json index 20ad530..572192b 100644 --- a/subspace_decoder/configs/seqlen128_mla_q96_k64.json +++ b/subspace_decoder/configs/seqlen128_mla_q96_k64.json @@ -7,13 +7,15 @@ "num_nextn_predict_layers": 1, "moe_intermediate_size": 128, "intermediate_size": 1024, - "n_shared_experts": 1, - "n_routed_experts": 4, + + "n_shared_experts": 0, + "n_routed_experts": 0, + "num_experts_per_tok": 0, + "moe_layer_freq": 0, + "first_k_dense_replace": 0, + "ep_size": 1, "routed_scaling_factor": 1, - "num_experts_per_tok": 2, - "moe_layer_freq": 2, - "first_k_dense_replace": 1, "topk_method": "softmax_aux", "n_group": 1, "topk_group": 1, From 79074715839ca087829fc4dc7412b55d246c96ac Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Thu, 25 Sep 2025 12:02:58 +0000 Subject: [PATCH 3/4] Tweaked params in config --- .../configs/seqlen128_mla-on_q96_k64_o96.json | 10 ++++++---- subspace_decoder/configs/seqlen128_mla_q96_k64.json | 10 ++++++---- subspace_decoder/scripts/finetune_sst2.py | 6 +++--- subspace_decoder/scripts/train.py | 10 +++++----- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json index 9afe46b..bdc7b7e 100644 --- a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json +++ b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json @@ -5,12 +5,12 @@ "hidden_size": 256, "num_hidden_layers": 6, "num_nextn_predict_layers": 1, - "moe_intermediate_size": 128, + "moe_intermediate_size": 256, "intermediate_size": 1024, "n_shared_experts": 0, - "n_routed_experts": 0, - "num_experts_per_tok": 0, + "n_routed_experts": 4, + "num_experts_per_tok": 2, "moe_layer_freq": 0, "first_k_dense_replace": 0, @@ -68,7 +68,9 @@ "fp16": false, "torch_compile": true, "torch_compile_backend": "inductor", - "torch_compile_mode": "default" + "torch_compile_mode": "default", + + "best_checkpoint": "checkpoints/seqlen128_mla-on_q96_k64_o96/checkpoint-12000" }, "fine_tune": { "task": "sst2", diff --git a/subspace_decoder/configs/seqlen128_mla_q96_k64.json b/subspace_decoder/configs/seqlen128_mla_q96_k64.json index 572192b..7811aea 100644 --- a/subspace_decoder/configs/seqlen128_mla_q96_k64.json +++ b/subspace_decoder/configs/seqlen128_mla_q96_k64.json @@ -5,12 +5,12 @@ "hidden_size": 256, "num_hidden_layers": 6, "num_nextn_predict_layers": 1, - "moe_intermediate_size": 128, + "moe_intermediate_size": 256, "intermediate_size": 1024, "n_shared_experts": 0, - "n_routed_experts": 0, - "num_experts_per_tok": 0, + "n_routed_experts": 4, + "num_experts_per_tok": 2, "moe_layer_freq": 0, "first_k_dense_replace": 0, @@ -68,7 +68,9 @@ "fp16": false, "torch_compile": true, "torch_compile_backend": "inductor", - "torch_compile_mode": "default" + "torch_compile_mode": "default", + + "best_checkpoint": "checkpoints/seqlen128_mla_q96_k64/checkpoint-12000" }, "fine_tune": { "task": "sst2", diff --git a/subspace_decoder/scripts/finetune_sst2.py b/subspace_decoder/scripts/finetune_sst2.py index 2505558..4d00bf2 100644 --- a/subspace_decoder/scripts/finetune_sst2.py +++ b/subspace_decoder/scripts/finetune_sst2.py @@ -26,13 +26,13 @@ set_seed, ) -from utils import summarize_parameters, format_size - # Project import path (same pattern as your train.py) PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) +from utils import summarize_parameters, format_size + from layers.patch_o_proj import load_checkpoint_state_dict, load_and_patch_model, Variant from transformers import DeepseekV3Config, DeepseekV3ForCausalLM @@ -409,7 +409,7 @@ def __call__(self, eval_pred, compute_result=False): wandb.init( project="decoder-finetune-sst2", - name=ft.get("run_name", f"ft-sst2-{run_name}"), + name="balance param-compute budget" + " " + ft.get("run_name", f"ft-sst2-{run_name}"), config=full_cfg ) diff --git a/subspace_decoder/scripts/train.py b/subspace_decoder/scripts/train.py index 68e6c24..64faa7f 100644 --- a/subspace_decoder/scripts/train.py +++ b/subspace_decoder/scripts/train.py @@ -35,10 +35,6 @@ set_seed, ) -from utils import summarize_parameters, format_size -# To disable a warning. -os.environ["TOKENIZERS_PARALLELISM"] = "false" - # Make sure we can import modules from the decoder package PROJECT_ROOT = Path(__file__).resolve().parents[1] @@ -47,6 +43,10 @@ if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) +from utils import summarize_parameters, format_size +# To disable a warning. +os.environ["TOKENIZERS_PARALLELISM"] = "false" + from layers.patch_o_proj import patch_o_proj_implementation from transformers import DeepseekV3Config, DeepseekV3ForCausalLM @@ -299,7 +299,7 @@ def group_texts(examples): wandb.init( project="decoder-pretrain-wiki103", - name=ptrain_cfg["run_name"], + name=f'balance param-compute budget {ptrain_cfg["run_name"]}', config=full_cfg ) From 0fc07299ab39f4a264db6f88767268c6272f1766 Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Sat, 27 Sep 2025 09:10:14 -0500 Subject: [PATCH 4/4] Make some cleanup --- subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json | 4 +--- subspace_decoder/configs/seqlen128_mla_q96_k64.json | 4 +--- subspace_decoder/scripts/finetune_sst2.py | 2 +- subspace_decoder/scripts/train.py | 2 +- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json index bdc7b7e..5ca5da3 100644 --- a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json +++ b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json @@ -68,9 +68,7 @@ "fp16": false, "torch_compile": true, "torch_compile_backend": "inductor", - "torch_compile_mode": "default", - - "best_checkpoint": "checkpoints/seqlen128_mla-on_q96_k64_o96/checkpoint-12000" + "torch_compile_mode": "default" }, "fine_tune": { "task": "sst2", diff --git a/subspace_decoder/configs/seqlen128_mla_q96_k64.json b/subspace_decoder/configs/seqlen128_mla_q96_k64.json index 7811aea..ace0c01 100644 --- a/subspace_decoder/configs/seqlen128_mla_q96_k64.json +++ b/subspace_decoder/configs/seqlen128_mla_q96_k64.json @@ -68,9 +68,7 @@ "fp16": false, "torch_compile": true, "torch_compile_backend": "inductor", - "torch_compile_mode": "default", - - "best_checkpoint": "checkpoints/seqlen128_mla_q96_k64/checkpoint-12000" + "torch_compile_mode": "default" }, "fine_tune": { "task": "sst2", diff --git a/subspace_decoder/scripts/finetune_sst2.py b/subspace_decoder/scripts/finetune_sst2.py index 4d00bf2..5f6c745 100644 --- a/subspace_decoder/scripts/finetune_sst2.py +++ b/subspace_decoder/scripts/finetune_sst2.py @@ -409,7 +409,7 @@ def __call__(self, eval_pred, compute_result=False): wandb.init( project="decoder-finetune-sst2", - name="balance param-compute budget" + " " + ft.get("run_name", f"ft-sst2-{run_name}"), + name=ft.get("run_name", f"ft-sst2-{run_name}"), config=full_cfg ) diff --git a/subspace_decoder/scripts/train.py b/subspace_decoder/scripts/train.py index 64faa7f..772371f 100644 --- a/subspace_decoder/scripts/train.py +++ b/subspace_decoder/scripts/train.py @@ -299,7 +299,7 @@ def group_texts(examples): wandb.init( project="decoder-pretrain-wiki103", - name=f'balance param-compute budget {ptrain_cfg["run_name"]}', + name=ptrain_cfg["run_name"], config=full_cfg )