From 5b98c861306771b27a18ab61e47f6334a45bab2a Mon Sep 17 00:00:00 2001 From: Tony Date: Tue, 9 Jun 2026 21:30:32 -0400 Subject: [PATCH 1/2] Batch eval generation across turns via --eval_batch_size The evaluator generated one turn at a time: for every sample it looped over response ranges and called generate() per turn with batch size 1. Restructure into two phases: flatten_eval_turns expands every (sample, turn) pair into a flat work item (prefix ids, masked signal indices, per-sample encoder outputs, ground truth), then generation runs over chunks of --eval_batch_size turns, left-padded to the chunk max with signal indices shifted per item. Results are reassembled in the original order before the (unchanged) metric computation. --eval_batch_size defaults to 1, which preserves today's behavior exactly: same generate() calls in the same order, verified on a real eval (277 turn pairs, sampling on): 277/277 generations and all metrics identical to main. With batching, greedy decoding at eval_batch_size=4 reproduces 75-77% of eval_batch_size=1 generations exactly (signal and rgb configs, untrained connector); the remainder differ through bfloat16 batched kernels reaching different logit argmaxes on near-ties. Aggregate metrics agree to the third decimal. index_nested now returns the squeezed per-sample entry; its only caller is the new flatten step. --- src/configs/config.py | 2 + src/runners/evaluator.py | 144 +++++++++++++++++++++++++-------------- 2 files changed, 93 insertions(+), 53 deletions(-) diff --git a/src/configs/config.py b/src/configs/config.py index 42fbf7a..6274bd5 100644 --- a/src/configs/config.py +++ b/src/configs/config.py @@ -59,6 +59,8 @@ def get_args(mode: Mode) -> argparse.Namespace: help="Training phase: pretrain (raw text + bos/signal/eos, no chat template), sft (chat template), rl (sft + think/answer special tokens)") parser.add_argument("--explicit_thinking", action="store_true", default=False, help="Treat \\n as a fixed prompt prefix: mask loss up to and including \\n (SFT); inject it at generation to force thinking.") + if mode in {"eval", "inference"}: + parser.add_argument("--eval_batch_size", type=int, default=1, help="Number of turns generated per batch during eval/inference") if mode == "train": parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "adamw", "muon"], help="Optimizer type") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") diff --git a/src/runners/evaluator.py b/src/runners/evaluator.py index 0df03c3..d3fdfa5 100644 --- a/src/runners/evaluator.py +++ b/src/runners/evaluator.py @@ -212,7 +212,68 @@ def run_statistical_analysis(all_seeds_results): return statistical_results def index_nested(encoder_tokenizer_out, batch): - return {k: index_nested(v, batch) if isinstance(v, dict) else v[batch:batch+1] for k, v in encoder_tokenizer_out.items()} + return {k: index_nested(v, batch) if isinstance(v, dict) else v[batch] for k, v in encoder_tokenizer_out.items()} + + +def stack_nested(items): + return {k: stack_nested([item[k] for item in items]) if isinstance(items[0][k], dict) + else torch.stack([item[k] for item in items], dim=0) for k in items[0]} + + +def flatten_eval_turns(dataloader, args, needs_signal_injection): + """Expand every (sample, turn) pair into a flat generation work item.""" + dataset = dataloader.dataset + turns = [] + progress = tqdm(dataloader, desc="Flattening turns", disable=not is_main(), leave=False) + for batch_idx, batch in enumerate(progress): + B = batch["elm_input_ids"].shape[0] + for b in range(B): + full_ids = batch["elm_input_ids"][b].tolist() + full_attn = batch["elm_attention_mask"][b].tolist() + if needs_signal_injection: + signal_indices = batch["signal_id_indices"][b] + encoder_out = index_nested(batch["encoder_tokenizer_out"], b) + ranges = dataset.get_response_ranges(full_ids) + gt_texts = dataset.get_ground_truth_responses(full_ids, ranges) + if getattr(args, "dev", False): + print(f"\n--- Batch {batch_idx}, Sample {b} ---") + print(f"Total turns: {len(ranges)}") + dataset.assert_range_alignment(full_ids, ranges) + for (s, _), gt in zip(ranges, gt_texts): + turn = {"order": len(turns), "prefix_ids": full_ids[:s], + "prefix_attn": full_attn[:s], "gt_text": gt} + if needs_signal_injection: + masked_indices = signal_indices.clone() + masked_indices[masked_indices >= s] = -1 + turn["signal_id_indices"] = masked_indices + turn["encoder_tokenizer_out"] = encoder_out + turns.append(turn) + if train_dev_break(getattr(args, "dev", False), batch, 0): + break + return turns + + +def collate_turns(chunk, pad_token_id): + """Left-pad a chunk of turn items to its max prefix length for generation.""" + max_len = max(len(turn["prefix_ids"]) for turn in chunk) + input_ids, attention_mask = [], [] + for turn in chunk: + pad = max_len - len(turn["prefix_ids"]) + input_ids.append([pad_token_id] * pad + turn["prefix_ids"]) + attention_mask.append([0] * pad + turn["prefix_attn"]) + gen_batch = { + "elm_input_ids": torch.tensor(input_ids, dtype=torch.int64), + "elm_attention_mask": torch.tensor(attention_mask, dtype=torch.float32), + } + if "signal_id_indices" in chunk[0]: + shifted = [] + for turn in chunk: + pad = max_len - len(turn["prefix_ids"]) + indices = turn["signal_id_indices"] + shifted.append(torch.where(indices >= 0, indices + pad, indices)) + gen_batch["signal_id_indices"] = torch.stack(shifted, dim=0) + gen_batch["encoder_tokenizer_out"] = stack_nested([turn["encoder_tokenizer_out"] for turn in chunk]) + return gen_batch def pretrain_diagnostic_breakdown(refs, hyps): split = lambda s: {x.strip() for x in (s or "").split(";") if x.strip()} @@ -299,65 +360,42 @@ def save_incorrect_predictions_histogram_png(references, hypotheses, path, top_k print(f"Saved incorrect-predictions histogram to {path}") def evaluate(elm, dataloader, args): - show_progress = is_main() elm.eval() needs_signal_injection = args.elm in ("mlp_llava", "linear_llava", "base_elf", "patch_elf", "conv_elf") - progress = tqdm( - dataloader, - desc=f"LLM: {args.llm} ENCODER: {args.encoder}", - disable=not show_progress, - leave=False, - ) dataset = dataloader.dataset device = next(elm.parameters()).device - all_refs, all_hyps, all_prompts = [], [], [] + + turns = flatten_eval_turns(dataloader, args, needs_signal_injection) + + eval_batch_size = getattr(args, "eval_batch_size", 1) + pad_token_id = dataset.llm_tokenizer.pad_token_id + results = [] # (order, gt_text, gen_txt, prefix_ids) + progress = tqdm(range(0, len(turns), eval_batch_size), + desc=f"LLM: {args.llm} ENCODER: {args.encoder} (eval_bs={eval_batch_size})", + disable=not is_main(), leave=False) with torch.no_grad(): - for batch_idx, batch in enumerate(progress): - B = batch["elm_input_ids"].shape[0] - for b in range(B): - full_ids = batch["elm_input_ids"][b].tolist() - full_attn = batch["elm_attention_mask"][b].tolist() - if needs_signal_injection: - signal_indices = batch["signal_id_indices"][b] - full_encoder_tokenizer_out = index_nested(batch["encoder_tokenizer_out"], b) - ranges = dataset.get_response_ranges(full_ids) - gt_texts = dataset.get_ground_truth_responses(full_ids, ranges) + for start in progress: + chunk = turns[start:start + eval_batch_size] + gen_batch = collate_turns(chunk, pad_token_id) + gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} + gen_out = elm.generate(**gen_batch, max_new_tokens=args.max_new_tokens) + for turn, row in zip(chunk, gen_out): + gen_txt = dataset.get_generated_response_for_turn(turn["prefix_ids"], row.cpu().tolist()) if getattr(args, "dev", False): - print(f"\n--- Batch {batch_idx}, Sample {b} ---") - print(f"Total turns: {len(ranges)}") - dataset.assert_range_alignment(full_ids, ranges) - for turn_idx, ((s, _), gt) in enumerate(zip(ranges, gt_texts)): - sub_ids = full_ids[:s] - sub_attn = full_attn[:s] - gen_batch = { - "elm_input_ids": torch.tensor(sub_ids, dtype=torch.int64).unsqueeze(0), - "elm_attention_mask": torch.tensor(sub_attn, dtype=torch.float32).unsqueeze(0), - "max_new_tokens": args.max_new_tokens - } - if needs_signal_injection: - gen_batch["encoder_tokenizer_out"] = full_encoder_tokenizer_out - truncated_len = len(sub_ids) - masked_indices = signal_indices.clone() - masked_indices[masked_indices >= truncated_len] = -1 - gen_batch["signal_id_indices"] = masked_indices - gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} - gen_out = elm.generate(**gen_batch)[0].cpu().tolist() - gen_txt = dataset.get_generated_response_for_turn(sub_ids, gen_out) - if getattr(args, "dev", False): - print(f"\nTurn {turn_idx + 1}:") - print(f"\nGround Truth:\n{gt}") - print(f"\nGenerated:\n{gen_txt}") - print("-" * 100) - if gt and gen_txt: - all_prompts.append(dataset.llm_tokenizer.decode(sub_ids, skip_special_tokens=True).strip()) - all_refs.append(gt) - all_hyps.append(gen_txt) - if train_dev_break(getattr(args, "dev", False), batch, 0): - break - # if batch_idx == 10: - # break - # input() + print(f"\nTurn (order {turn['order']}):") + print(f"\nGround Truth:\n{turn['gt_text']}") + print(f"\nGenerated:\n{gen_txt}") + print("-" * 100) + results.append((turn["order"], turn["gt_text"], gen_txt, turn["prefix_ids"])) + + results.sort(key=lambda r: r[0]) + all_refs, all_hyps, all_prompts = [], [], [] + for _, gt, gen_txt, prefix_ids in results: + if gt and gen_txt: + all_prompts.append(dataset.llm_tokenizer.decode(prefix_ids, skip_special_tokens=True).strip()) + all_refs.append(gt) + all_hyps.append(gen_txt) refs_t, refs_a = map(list, zip(*map(split_response, all_refs))) if all_refs else ([], []) hyps_t, hyps_a = map(list, zip(*map(split_response, all_hyps))) if all_hyps else ([], []) think_pairs = [(r, h) for r, h in zip(refs_t, hyps_t) if r and h] From 56d18e789c0b076f25e3468824649691a4343e08 Mon Sep 17 00:00:00 2001 From: Tony Date: Tue, 9 Jun 2026 21:38:12 -0400 Subject: [PATCH 2/2] Support distributed eval via --distributed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Shard the flattened turn list across ranks (turns[rank::world_size] — an exact partition, so no post-gather deduplication is needed), generate locally, then all_gather_object the results and sort by the original order; every rank computes identical metrics from the full gathered set. The model is unwrapped from DDP for generate(). main_evaluator initializes/destroys the process group when --distributed is passed and gates prints and file writes to rank 0. Single-process behavior is unchanged (world_size == 1 keeps the full turn list and skips the gather). --- src/main_evaluator.py | 41 ++++++++++++++++++++--------------- src/runners/evaluator.py | 46 ++++++++++++++++++++++++++-------------- 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/src/main_evaluator.py b/src/main_evaluator.py index 4321dae..c64fc82 100644 --- a/src/main_evaluator.py +++ b/src/main_evaluator.py @@ -5,7 +5,7 @@ from pathlib import Path from configs.config import get_args -from utils.gpu_manager import GPUSetup +from utils.gpu_manager import GPUSetup, init_dist, cleanup, is_main from utils.seed_manager import set_seed from dataloaders.build_dataloader import BuildDataLoader from elms.build_elm import BuildELM @@ -19,6 +19,8 @@ def main(): mode = "eval" args = get_args(mode) args.mode = mode + if args.distributed: + init_dist() # folds = ["1", "2", "3", "4", "5"] # seeds = [1337, 1338, 1339, 1340, 1341] folds = ["1"] @@ -36,7 +38,8 @@ def main(): results_file = os.path.join(checkpoint_dir, f"{ckpt_file_name}_{data_name}_{sys_prompt_name}_{args.perturb}_{args.max_new_tokens}.json") for fold in folds: for seed in seeds: - print(f"Evaluating fold {fold} with seed {seed}") + if is_main(): + print(f"Evaluating fold {fold} with seed {seed}") args.fold = fold args.seed = seed set_seed(args.seed) @@ -50,7 +53,7 @@ def main(): gpu_setup.print_model_device(elm, f"{args.llm}_{args.encoder}") out = evaluate(elm, dataloader, args) all_metrics.append(out) - if len(all_metrics) == 1: + if is_main() and len(all_metrics) == 1: examples_path = results_file.replace(".json", f"examples_{args.max_new_tokens}.json") examples = [{"prompt": p, "predicted": h, "ground_truth": r} for p, h, r in zip(out["prompts"], out["hypotheses"], out["references"])] @@ -60,20 +63,24 @@ def main(): del elm, elm_components, build_elm, gpu_setup, dataloader, build_dataloader gc.collect() torch.cuda.empty_cache() - if "confusion_matrix" in out: - cm_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}.png") - save_confusion_matrix_png(out["confusion_matrix"], cm_path) - other_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_other.png") - save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k = 10) - incorrect_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_incorrect.png") - save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path) - if "pretrain_breakdown" in out: - save_pretrain_breakdown_pngs(out["pretrain_breakdown"], - results_file.replace(".json", f"{fold}_{seed}_pretrain")) - statistical_results = run_statistical_analysis(all_metrics) - with open(results_file, "w") as f: - json.dump(statistical_results, f, indent=2) - print(f"Saved evaluation results to {results_file}") + if is_main(): + if "confusion_matrix" in out: + cm_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}.png") + save_confusion_matrix_png(out["confusion_matrix"], cm_path) + other_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_other.png") + save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k = 10) + incorrect_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_incorrect.png") + save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path) + if "pretrain_breakdown" in out: + save_pretrain_breakdown_pngs(out["pretrain_breakdown"], + results_file.replace(".json", f"{fold}_{seed}_pretrain")) + if is_main(): + statistical_results = run_statistical_analysis(all_metrics) + with open(results_file, "w") as f: + json.dump(statistical_results, f, indent=2) + print(f"Saved evaluation results to {results_file}") + if args.distributed: + cleanup() if __name__ == "__main__": diff --git a/src/runners/evaluator.py b/src/runners/evaluator.py index d3fdfa5..ce78a76 100644 --- a/src/runners/evaluator.py +++ b/src/runners/evaluator.py @@ -12,7 +12,9 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt -from utils.gpu_manager import is_main, train_dev_break, batch_to_device +import torch.distributed as dist + +from utils.gpu_manager import is_main, train_dev_break, batch_to_device, get_rank, get_world_size _THINK_RE = re.compile(r"(.*?)", re.DOTALL) _ANSWER_RE = re.compile(r"(.*?)", re.DOTALL) @@ -367,28 +369,37 @@ def evaluate(elm, dataloader, args): device = next(elm.parameters()).device turns = flatten_eval_turns(dataloader, args, needs_signal_injection) + # Every rank flattens the full dataset identically; each generates a + # disjoint stride of the turn list and results are gathered afterwards. + world_size = get_world_size() + local_turns = turns[get_rank()::world_size] if world_size > 1 else turns + gen_model = elm.module if hasattr(elm, "module") else elm # unwrap DDP for generate() eval_batch_size = getattr(args, "eval_batch_size", 1) pad_token_id = dataset.llm_tokenizer.pad_token_id results = [] # (order, gt_text, gen_txt, prefix_ids) - progress = tqdm(range(0, len(turns), eval_batch_size), + progress = tqdm(range(0, len(local_turns), eval_batch_size), desc=f"LLM: {args.llm} ENCODER: {args.encoder} (eval_bs={eval_batch_size})", disable=not is_main(), leave=False) with torch.no_grad(): for start in progress: - chunk = turns[start:start + eval_batch_size] + chunk = local_turns[start:start + eval_batch_size] gen_batch = collate_turns(chunk, pad_token_id) gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} - gen_out = elm.generate(**gen_batch, max_new_tokens=args.max_new_tokens) + gen_out = gen_model.generate(**gen_batch, max_new_tokens=args.max_new_tokens) for turn, row in zip(chunk, gen_out): gen_txt = dataset.get_generated_response_for_turn(turn["prefix_ids"], row.cpu().tolist()) - if getattr(args, "dev", False): + if getattr(args, "dev", False) and is_main(): print(f"\nTurn (order {turn['order']}):") print(f"\nGround Truth:\n{turn['gt_text']}") print(f"\nGenerated:\n{gen_txt}") print("-" * 100) results.append((turn["order"], turn["gt_text"], gen_txt, turn["prefix_ids"])) + if world_size > 1: + gathered = [None] * world_size + dist.all_gather_object(gathered, results) + results = [r for shard in gathered for r in shard] results.sort(key=lambda r: r[0]) all_refs, all_hyps, all_prompts = [], [], [] for _, gt, gen_txt, prefix_ids in results: @@ -402,12 +413,13 @@ def evaluate(elm, dataloader, args): results = {"answer": evaluate_strings(refs_a, hyps_a)} if think_pairs: results["thinking"] = evaluate_strings(*map(list, zip(*think_pairs))) - print("\n=== N-Turn Evaluation (generated vs. gold response only) ===") - print(f"Pairs: {len(all_refs)} (thinking pairs: {len(think_pairs)})") - for group, mdict in results.items(): - print(f"[{group}]") - for k, v in mdict.items(): - print(f" {k}: {v:.4f}") + if is_main(): + print("\n=== N-Turn Evaluation (generated vs. gold response only) ===") + print(f"Pairs: {len(all_refs)} (thinking pairs: {len(think_pairs)})") + for group, mdict in results.items(): + print(f"[{group}]") + for k, v in mdict.items(): + print(f" {k}: {v:.4f}") out = { "num_pairs": len(all_refs), "metrics": results, @@ -418,13 +430,15 @@ def evaluate(elm, dataloader, args): if getattr(args, "train_phase", "sft") == "pretrain" and refs_a: breakdown = pretrain_diagnostic_breakdown(refs_a, hyps_a) out["pretrain_breakdown"] = breakdown - print(f"\n=== Pretrain diagnostic breakdown (N={breakdown['n']:,}) ===") - print(f" Matched={breakdown['matched']:,} Not matched={breakdown['not_matched']:,} Other={breakdown['other']:,}") - print(f" Missed_inst={breakdown['missed_inst']:,} Extra_inst={breakdown['extra_inst']:,}") - print(f" Only missed={breakdown['only_missed']:,} Only extra={breakdown['only_extra']:,} Both={breakdown['both']:,}") + if is_main(): + print(f"\n=== Pretrain diagnostic breakdown (N={breakdown['n']:,}) ===") + print(f" Matched={breakdown['matched']:,} Not matched={breakdown['not_matched']:,} Other={breakdown['other']:,}") + print(f" Missed_inst={breakdown['missed_inst']:,} Extra_inst={breakdown['extra_inst']:,}") + print(f" Only missed={breakdown['only_missed']:,} Only extra={breakdown['only_extra']:,} Both={breakdown['both']:,}") if any(d.startswith("ecg-comp") for d in args.data): per_class_acc, confusion_matrix, other_counts = compute_classification_metrics(refs_a, hyps_a) - print_classification_metrics(per_class_acc, confusion_matrix) + if is_main(): + print_classification_metrics(per_class_acc, confusion_matrix) results["per_class_acc"] = per_class_acc out["confusion_matrix"] = confusion_matrix out["other_output_counts"] = other_counts