diff --git a/src/configs/config.py b/src/configs/config.py index 42fbf7a..6274bd5 100644 --- a/src/configs/config.py +++ b/src/configs/config.py @@ -59,6 +59,8 @@ def get_args(mode: Mode) -> argparse.Namespace: help="Training phase: pretrain (raw text + bos/signal/eos, no chat template), sft (chat template), rl (sft + think/answer special tokens)") parser.add_argument("--explicit_thinking", action="store_true", default=False, help="Treat \\n as a fixed prompt prefix: mask loss up to and including \\n (SFT); inject it at generation to force thinking.") + if mode in {"eval", "inference"}: + parser.add_argument("--eval_batch_size", type=int, default=1, help="Number of turns generated per batch during eval/inference") if mode == "train": parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "adamw", "muon"], help="Optimizer type") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") diff --git a/src/runners/evaluator.py b/src/runners/evaluator.py index 0df03c3..d3fdfa5 100644 --- a/src/runners/evaluator.py +++ b/src/runners/evaluator.py @@ -212,7 +212,68 @@ def run_statistical_analysis(all_seeds_results): return statistical_results def index_nested(encoder_tokenizer_out, batch): - return {k: index_nested(v, batch) if isinstance(v, dict) else v[batch:batch+1] for k, v in encoder_tokenizer_out.items()} + return {k: index_nested(v, batch) if isinstance(v, dict) else v[batch] for k, v in encoder_tokenizer_out.items()} + + +def stack_nested(items): + return {k: stack_nested([item[k] for item in items]) if isinstance(items[0][k], dict) + else torch.stack([item[k] for item in items], dim=0) for k in items[0]} + + +def flatten_eval_turns(dataloader, args, needs_signal_injection): + """Expand every (sample, turn) pair into a flat generation work item.""" + dataset = dataloader.dataset + turns = [] + progress = tqdm(dataloader, desc="Flattening turns", disable=not is_main(), leave=False) + for batch_idx, batch in enumerate(progress): + B = batch["elm_input_ids"].shape[0] + for b in range(B): + full_ids = batch["elm_input_ids"][b].tolist() + full_attn = batch["elm_attention_mask"][b].tolist() + if needs_signal_injection: + signal_indices = batch["signal_id_indices"][b] + encoder_out = index_nested(batch["encoder_tokenizer_out"], b) + ranges = dataset.get_response_ranges(full_ids) + gt_texts = dataset.get_ground_truth_responses(full_ids, ranges) + if getattr(args, "dev", False): + print(f"\n--- Batch {batch_idx}, Sample {b} ---") + print(f"Total turns: {len(ranges)}") + dataset.assert_range_alignment(full_ids, ranges) + for (s, _), gt in zip(ranges, gt_texts): + turn = {"order": len(turns), "prefix_ids": full_ids[:s], + "prefix_attn": full_attn[:s], "gt_text": gt} + if needs_signal_injection: + masked_indices = signal_indices.clone() + masked_indices[masked_indices >= s] = -1 + turn["signal_id_indices"] = masked_indices + turn["encoder_tokenizer_out"] = encoder_out + turns.append(turn) + if train_dev_break(getattr(args, "dev", False), batch, 0): + break + return turns + + +def collate_turns(chunk, pad_token_id): + """Left-pad a chunk of turn items to its max prefix length for generation.""" + max_len = max(len(turn["prefix_ids"]) for turn in chunk) + input_ids, attention_mask = [], [] + for turn in chunk: + pad = max_len - len(turn["prefix_ids"]) + input_ids.append([pad_token_id] * pad + turn["prefix_ids"]) + attention_mask.append([0] * pad + turn["prefix_attn"]) + gen_batch = { + "elm_input_ids": torch.tensor(input_ids, dtype=torch.int64), + "elm_attention_mask": torch.tensor(attention_mask, dtype=torch.float32), + } + if "signal_id_indices" in chunk[0]: + shifted = [] + for turn in chunk: + pad = max_len - len(turn["prefix_ids"]) + indices = turn["signal_id_indices"] + shifted.append(torch.where(indices >= 0, indices + pad, indices)) + gen_batch["signal_id_indices"] = torch.stack(shifted, dim=0) + gen_batch["encoder_tokenizer_out"] = stack_nested([turn["encoder_tokenizer_out"] for turn in chunk]) + return gen_batch def pretrain_diagnostic_breakdown(refs, hyps): split = lambda s: {x.strip() for x in (s or "").split(";") if x.strip()} @@ -299,65 +360,42 @@ def save_incorrect_predictions_histogram_png(references, hypotheses, path, top_k print(f"Saved incorrect-predictions histogram to {path}") def evaluate(elm, dataloader, args): - show_progress = is_main() elm.eval() needs_signal_injection = args.elm in ("mlp_llava", "linear_llava", "base_elf", "patch_elf", "conv_elf") - progress = tqdm( - dataloader, - desc=f"LLM: {args.llm} ENCODER: {args.encoder}", - disable=not show_progress, - leave=False, - ) dataset = dataloader.dataset device = next(elm.parameters()).device - all_refs, all_hyps, all_prompts = [], [], [] + + turns = flatten_eval_turns(dataloader, args, needs_signal_injection) + + eval_batch_size = getattr(args, "eval_batch_size", 1) + pad_token_id = dataset.llm_tokenizer.pad_token_id + results = [] # (order, gt_text, gen_txt, prefix_ids) + progress = tqdm(range(0, len(turns), eval_batch_size), + desc=f"LLM: {args.llm} ENCODER: {args.encoder} (eval_bs={eval_batch_size})", + disable=not is_main(), leave=False) with torch.no_grad(): - for batch_idx, batch in enumerate(progress): - B = batch["elm_input_ids"].shape[0] - for b in range(B): - full_ids = batch["elm_input_ids"][b].tolist() - full_attn = batch["elm_attention_mask"][b].tolist() - if needs_signal_injection: - signal_indices = batch["signal_id_indices"][b] - full_encoder_tokenizer_out = index_nested(batch["encoder_tokenizer_out"], b) - ranges = dataset.get_response_ranges(full_ids) - gt_texts = dataset.get_ground_truth_responses(full_ids, ranges) + for start in progress: + chunk = turns[start:start + eval_batch_size] + gen_batch = collate_turns(chunk, pad_token_id) + gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} + gen_out = elm.generate(**gen_batch, max_new_tokens=args.max_new_tokens) + for turn, row in zip(chunk, gen_out): + gen_txt = dataset.get_generated_response_for_turn(turn["prefix_ids"], row.cpu().tolist()) if getattr(args, "dev", False): - print(f"\n--- Batch {batch_idx}, Sample {b} ---") - print(f"Total turns: {len(ranges)}") - dataset.assert_range_alignment(full_ids, ranges) - for turn_idx, ((s, _), gt) in enumerate(zip(ranges, gt_texts)): - sub_ids = full_ids[:s] - sub_attn = full_attn[:s] - gen_batch = { - "elm_input_ids": torch.tensor(sub_ids, dtype=torch.int64).unsqueeze(0), - "elm_attention_mask": torch.tensor(sub_attn, dtype=torch.float32).unsqueeze(0), - "max_new_tokens": args.max_new_tokens - } - if needs_signal_injection: - gen_batch["encoder_tokenizer_out"] = full_encoder_tokenizer_out - truncated_len = len(sub_ids) - masked_indices = signal_indices.clone() - masked_indices[masked_indices >= truncated_len] = -1 - gen_batch["signal_id_indices"] = masked_indices - gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} - gen_out = elm.generate(**gen_batch)[0].cpu().tolist() - gen_txt = dataset.get_generated_response_for_turn(sub_ids, gen_out) - if getattr(args, "dev", False): - print(f"\nTurn {turn_idx + 1}:") - print(f"\nGround Truth:\n{gt}") - print(f"\nGenerated:\n{gen_txt}") - print("-" * 100) - if gt and gen_txt: - all_prompts.append(dataset.llm_tokenizer.decode(sub_ids, skip_special_tokens=True).strip()) - all_refs.append(gt) - all_hyps.append(gen_txt) - if train_dev_break(getattr(args, "dev", False), batch, 0): - break - # if batch_idx == 10: - # break - # input() + print(f"\nTurn (order {turn['order']}):") + print(f"\nGround Truth:\n{turn['gt_text']}") + print(f"\nGenerated:\n{gen_txt}") + print("-" * 100) + results.append((turn["order"], turn["gt_text"], gen_txt, turn["prefix_ids"])) + + results.sort(key=lambda r: r[0]) + all_refs, all_hyps, all_prompts = [], [], [] + for _, gt, gen_txt, prefix_ids in results: + if gt and gen_txt: + all_prompts.append(dataset.llm_tokenizer.decode(prefix_ids, skip_special_tokens=True).strip()) + all_refs.append(gt) + all_hyps.append(gen_txt) refs_t, refs_a = map(list, zip(*map(split_response, all_refs))) if all_refs else ([], []) hyps_t, hyps_a = map(list, zip(*map(split_response, all_hyps))) if all_hyps else ([], []) think_pairs = [(r, h) for r, h in zip(refs_t, hyps_t) if r and h]