LLaMA3-CODI-8B HF checkpoint evaluates far below the reported GSM8K-Aug accuracy #15

@epc314

Hi SIM-CoT team,

Thank you for sharing this great work and releasing the CODI checkpoints. I am trying to reproduce the HF-released LLaMA CODI checkpoints with the official InternLM/SIM-CoT evaluation code, and I found that the 1B and 3B checkpoints reproduce the reported GSM8K-Aug numbers, but the 8B checkpoint does not.

Summary

Using the official CODI generation loop from CODI/test.py at commit d1d56afbe705cfcbf5911b588da6c3083825598d, I get:

| HF checkpoint | README reported GSM8K-Aug CODI SIM-CoT (%) | My reproduced GSM8K test acc (%) |
| --- | --- | --- |
| internlm/SIM_COT-LLaMA3-CODI-1B | 55.6 | 55.88 |
| internlm/SIM_COT-LLaMA3-CODI-3B | 62.3 | 62.09 |
| internlm/SIM_COT-LLaMA3-CODI-8B | 64.1 | 34.27 |

Since 1B and 3B land within about 0.3 points of the README numbers, the evaluation environment and reproduction path appear to be correct, and something is wrong or underspecified for the released 8B checkpoint specifically.
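For reference, here is a small sketch that computes the gap for each checkpoint from the numbers above. The accuracies are copied from this issue; the test-set size of 1319 is my assumption about the standard GSM8K test split and is used only to convert percentages into approximate counts.

```python
# Reported vs. reproduced accuracy (percent), copied from this issue.
results = {
    "internlm/SIM_COT-LLaMA3-CODI-1B": (55.6, 55.88),
    "internlm/SIM_COT-LLaMA3-CODI-3B": (62.3, 62.09),
    "internlm/SIM_COT-LLaMA3-CODI-8B": (64.1, 34.27),
}

# Assumption: the standard GSM8K test split with 1319 problems.
N_TEST = 1319


def summarize(results, n_test=N_TEST):
    """Return {checkpoint: (gap in points, approximate number correct)}."""
    rows = {}
    for name, (reported, reproduced) in results.items():
        gap = round(reproduced - reported, 2)
        approx_correct = round(reproduced / 100 * n_test)
        rows[name] = (gap, approx_correct)
    return rows


summary = summarize(results)
for name, (gap, n_correct) in summary.items():
    print(f"{name}: gap {gap:+.2f} points, ~{n_correct}/{N_TEST} correct")
```

Under that assumption, the 8B gap is roughly 30 points, which is far outside what seed or batch-size variation would normally explain.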

Additional packaging issue

All three official HF repos ship a config.json of size 0 bytes. Also, the official CODI/test.py currently loads only:

load_file(os.path.join(model_args.ckpt_dir, "model.safetensors"))
# fallback:
torch.load(os.path.join(model_args.ckpt_dir, "pytorch_model.bin"))

This works for the 1B release because it has a single model.safetensors, but the 3B and 8B releases are sharded and only provide:

model-000xx-of-000xx.safetensors
model.safetensors.index.json

So the official script cannot directly load the released 3B/8B checkpoints without a small sharded-safetensors loader.

After adding only a sharded loader, the 3B checkpoint reproduces the README number, while 8B remains far below the README number.
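The packaging problems described above can be detected on a local download before running the evaluator. This is only a sketch: describe_checkpoint_dir is a hypothetical helper name, and it checks nothing beyond the file names and the zero-byte config.json mentioned in this issue.

```python
import os


def describe_checkpoint_dir(ckpt_dir):
    """Report which weight layout a local checkpoint directory uses."""

    def size(name):
        # Return the file size in bytes, or None if the file is absent.
        path = os.path.join(ckpt_dir, name)
        return os.path.getsize(path) if os.path.isfile(path) else None

    if size("model.safetensors") is not None:
        layout = "single-safetensors"
    elif size("model.safetensors.index.json") is not None:
        layout = "sharded-safetensors"
    elif size("pytorch_model.bin") is not None:
        layout = "pytorch-bin"
    else:
        layout = "unknown"

    config_size = size("config.json")
    return {
        "layout": layout,
        "config_missing": config_size is None,
        "config_empty": config_size == 0,
    }
```

A "sharded-safetensors" layout is the case the unmodified test.py cannot load, and config_empty being true matches the 0-byte config.json files listed later in this issue.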

Minimal changes made to run the official evaluator locally

I used the official CODI/test.py generation logic and changed only the following:

  1. Removed the interactive import pdb; pdb.set_trace() left inside the decoding loop.
  2. Replaced the hard-coded author-local GSM8K path with my local GSM8K test JSON path.
  3. Added support for loading model.safetensors.index.json sharded checkpoints.
  4. Redirected the output JSON path from the author-local /mnt/shared-storage-user/... path to a local output directory.

The sharded loader is:

import json
import os

import torch
from safetensors.torch import load_file


def load_codi_checkpoint_state_dict(ckpt_dir):
    # Single-file release (the 1B case): load it directly.
    single_path = os.path.join(ckpt_dir, "model.safetensors")
    if os.path.exists(single_path):
        return load_file(single_path)

    # Sharded release (the 3B/8B case): merge every shard named in the index.
    index_path = os.path.join(ckpt_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        state_dict = {}
        for shard_name in sorted(set(index["weight_map"].values())):
            state_dict.update(load_file(os.path.join(ckpt_dir, shard_name)))
        return state_dict

    # Same fallback as the original script.
    return torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"))

For the 8B checkpoint, load_state_dict(strict=False) reports no missing main-model keys; the unexpected keys are decoder.*, which is expected when evaluating with use_decoder=False.
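The key check described above can be made easier to read by grouping the reported keys by prefix. This is a sketch: summarize_keys is a hypothetical helper, and the usage in the comments assumes the object returned by PyTorch's load_state_dict(strict=False), which exposes missing_keys and unexpected_keys.

```python
from collections import Counter


def summarize_keys(keys, depth=1):
    """Count parameter names by their first `depth` dot-separated components."""
    return Counter(".".join(key.split(".")[:depth]) for key in keys)


# Hypothetical usage inside the evaluator (model and state_dict come from test.py):
#   result = model.load_state_dict(state_dict, strict=False)
#   print("missing:", summarize_keys(result.missing_keys))
#   print("unexpected:", summarize_keys(result.unexpected_keys))
```

If the only unexpected prefix is decoder and the missing list has no main-model entries, that matches what I observed for the 8B checkpoint.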

Commands

1B

CUDA_VISIBLE_DEVICES=0 python test.py \
  --data_name gsm8k \
  --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
  --seed 11 \
  --model_max_length 512 \
  --bf16 true \
  --lora_r 128 --lora_alpha 32 --lora_init true \
  --batch_size 128 \
  --greedy True \
  --num_latent 6 \
  --use_prj True \
  --prj_dim 2048 \
  --prj_no_ln False \
  --prj_dropout 0.0 \
  --inf_latent_iterations 6 \
  --inf_num_iterations 1 \
  --remove_eos True \
  --use_lora True \
  --ckpt_dir <path-to-SIM_COT-LLaMA3-CODI-1B>

Result:

GSM8K test accuracy: 55.88%
average length of COT: 7.196

3B

CUDA_VISIBLE_DEVICES=0 python test.py \
  --data_name gsm8k \
  --model_name_or_path meta-llama/Llama-3.2-3B-Instruct \
  --seed 11 \
  --model_max_length 512 \
  --bf16 true \
  --lora_r 128 --lora_alpha 32 --lora_init true \
  --batch_size 128 \
  --greedy True \
  --num_latent 6 \
  --use_prj True \
  --prj_dim 3072 \
  --prj_no_ln False \
  --prj_dropout 0.0 \
  --inf_latent_iterations 6 \
  --inf_num_iterations 1 \
  --remove_eos True \
  --use_lora True \
  --ckpt_dir <path-to-SIM_COT-LLaMA3-CODI-3B>

Result after adding sharded checkpoint loading:

GSM8K test accuracy: 62.09%
average length of COT: 7.193

8B

The README script points to a local Meta-Llama-3.1-8B-Instruct snapshot. Since the HF checkpoint has an empty config.json, I used the public LLaMA-3.1-8B-Instruct-compatible config/tokenizer from NousResearch/Meta-Llama-3.1-8B-Instruct. The released CODI state dict contains the full codi.base_model...base_layer weights, so the base model weights should be overwritten by the checkpoint.

CUDA_VISIBLE_DEVICES=0 python test.py \
  --data_name gsm8k \
  --model_name_or_path NousResearch/Meta-Llama-3.1-8B-Instruct \
  --seed 11 \
  --model_max_length 512 \
  --bf16 true \
  --lora_r 128 --lora_alpha 32 --lora_init true \
  --batch_size 128 \
  --greedy True \
  --num_latent 6 \
  --use_prj True \
  --prj_dim 4096 \
  --prj_no_ln False \
  --prj_dropout 0.0 \
  --inf_latent_iterations 6 \
  --inf_num_iterations 1 \
  --remove_eos True \
  --use_lora True \
  --ckpt_dir <path-to-SIM_COT-LLaMA3-CODI-8B>

Result:

GSM8K test accuracy: 34.27%
average length of COT: 7.670

I also tried the same 8B checkpoint with unsloth/Meta-Llama-3.1-8B-Instruct as the base config/tokenizer in a direct teacher-generation sanity check, and got the same low-accuracy range (~34.8%).
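The statement that the released state dict carries base_layer weights can be checked from model.safetensors.index.json alone, without loading any tensors. This is a sketch under assumptions: base_layer_shards is a hypothetical helper, and matching on the substring ".base_layer." is my simplification of the key pattern. It only confirms that such tensor names are listed in the index, not that their values match any particular base model.

```python
import json


def base_layer_shards(index_path, marker=".base_layer."):
    """Map each shard file to the number of tensor names containing `marker`."""
    with open(index_path, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    counts = {}
    for tensor_name, shard_name in weight_map.items():
        if marker in tensor_name:
            counts[shard_name] = counts.get(shard_name, 0) + 1
    return counts
```

An empty result would mean the index lists no base_layer tensors, in which case the weights from --model_name_or_path would be used unchanged and the choice of base snapshot would matter much more.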

HF checkpoint revisions and file structure

internlm/SIM_COT-LLaMA3-CODI-1B
revision: 6e16fe4215025a9c48bc90d214c57349f247b017
config.json: 0 bytes
model.safetensors: 6388325872 bytes

internlm/SIM_COT-LLaMA3-CODI-3B
revision: a2db0e50d53bdbf5f42d508899ab62c7ced5dbe7
config.json: 0 bytes
model-00001-of-00003.safetensors
model-00002-of-00003.safetensors
model-00003-of-00003.safetensors
model.safetensors.index.json

internlm/SIM_COT-LLaMA3-CODI-8B
revision: de14926b5164b3ecdfbd7d383f235a4d5b983d7d
config.json: 0 bytes
model-00001-of-00007.safetensors
...
model-00007-of-00007.safetensors
model.safetensors.index.json

Environment

OS/container: Linux
Python: 3.10.12
GPU: NVIDIA H200 NVL, 143771 MiB
Driver: 570.148.08
CUDA: 12.8

torch: 2.7.1+cu128
transformers: 4.49.0
datasets: 3.1.0
peft: 0.18.1
accelerate: 1.10.0
safetensors: 0.5.3
huggingface_hub: 0.35.3

Could you please check whether the released internlm/SIM_COT-LLaMA3-CODI-8B checkpoint is the correct 8B checkpoint used for the README table, and/or provide the exact base model revision/config/tokenizer needed to reproduce the reported 64.1% GSM8K-Aug result?

It would also help if the HF repos included valid config.json files and if CODI/test.py supported sharded safetensors via model.safetensors.index.json.
