diff --git a/src/configs/config.py b/src/configs/config.py index 42fbf7a..6274bd5 100644 --- a/src/configs/config.py +++ b/src/configs/config.py @@ -59,6 +59,8 @@ def get_args(mode: Mode) -> argparse.Namespace: help="Training phase: pretrain (raw text + bos/signal/eos, no chat template), sft (chat template), rl (sft + think/answer special tokens)") parser.add_argument("--explicit_thinking", action="store_true", default=False, help="Treat \\n as a fixed prompt prefix: mask loss up to and including \\n (SFT); inject it at generation to force thinking.") + if mode in {"eval", "inference"}: + parser.add_argument("--eval_batch_size", type=int, default=1, help="Number of turns generated per batch during eval/inference") if mode == "train": parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "adamw", "muon"], help="Optimizer type") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") diff --git a/src/main_evaluator.py b/src/main_evaluator.py index 4321dae..c64fc82 100644 --- a/src/main_evaluator.py +++ b/src/main_evaluator.py @@ -5,7 +5,7 @@ from pathlib import Path from configs.config import get_args -from utils.gpu_manager import GPUSetup +from utils.gpu_manager import GPUSetup, init_dist, cleanup, is_main from utils.seed_manager import set_seed from dataloaders.build_dataloader import BuildDataLoader from elms.build_elm import BuildELM @@ -19,6 +19,8 @@ def main(): mode = "eval" args = get_args(mode) args.mode = mode + if args.distributed: + init_dist() # folds = ["1", "2", "3", "4", "5"] # seeds = [1337, 1338, 1339, 1340, 1341] folds = ["1"] @@ -36,7 +38,8 @@ def main(): results_file = os.path.join(checkpoint_dir, f"{ckpt_file_name}_{data_name}_{sys_prompt_name}_{args.perturb}_{args.max_new_tokens}.json") for fold in folds: for seed in seeds: - print(f"Evaluating fold {fold} with seed {seed}") + if is_main(): + print(f"Evaluating fold {fold} with seed {seed}") args.fold = fold args.seed = seed set_seed(args.seed) @@ -50,7 +53,7 @@ def main(): gpu_setup.print_model_device(elm, f"{args.llm}_{args.encoder}") out = evaluate(elm, dataloader, args) all_metrics.append(out) - if len(all_metrics) == 1: + if is_main() and len(all_metrics) == 1: examples_path = results_file.replace(".json", f"examples_{args.max_new_tokens}.json") examples = [{"prompt": p, "predicted": h, "ground_truth": r} for p, h, r in zip(out["prompts"], out["hypotheses"], out["references"])] @@ -60,20 +63,24 @@ def main(): del elm, elm_components, build_elm, gpu_setup, dataloader, build_dataloader gc.collect() torch.cuda.empty_cache() - if "confusion_matrix" in out: - cm_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}.png") - save_confusion_matrix_png(out["confusion_matrix"], cm_path) - other_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_other.png") - save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k = 10) - incorrect_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_incorrect.png") - save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path) - if "pretrain_breakdown" in out: - save_pretrain_breakdown_pngs(out["pretrain_breakdown"], - results_file.replace(".json", f"{fold}_{seed}_pretrain")) - statistical_results = run_statistical_analysis(all_metrics) - with open(results_file, "w") as f: - json.dump(statistical_results, f, indent=2) - print(f"Saved evaluation results to {results_file}") + if is_main(): + if "confusion_matrix" in out: + cm_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}.png") + save_confusion_matrix_png(out["confusion_matrix"], cm_path) + other_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_other.png") + save_other_outputs_histogram_png(out["other_output_counts"], other_path, top_k = 10) + incorrect_path = results_file.replace(".json", f"{fold}_{seed}_{args.max_new_tokens}_incorrect.png") + save_incorrect_predictions_histogram_png(out["references"], out["hypotheses"], incorrect_path) + if "pretrain_breakdown" in out: + save_pretrain_breakdown_pngs(out["pretrain_breakdown"], + results_file.replace(".json", f"{fold}_{seed}_pretrain")) + if is_main(): + statistical_results = run_statistical_analysis(all_metrics) + with open(results_file, "w") as f: + json.dump(statistical_results, f, indent=2) + print(f"Saved evaluation results to {results_file}") + if args.distributed: + cleanup() if __name__ == "__main__": diff --git a/src/runners/evaluator.py b/src/runners/evaluator.py index 0df03c3..ce78a76 100644 --- a/src/runners/evaluator.py +++ b/src/runners/evaluator.py @@ -12,7 +12,9 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt -from utils.gpu_manager import is_main, train_dev_break, batch_to_device +import torch.distributed as dist + +from utils.gpu_manager import is_main, train_dev_break, batch_to_device, get_rank, get_world_size _THINK_RE = re.compile(r"(.*?)", re.DOTALL) _ANSWER_RE = re.compile(r"(.*?)", re.DOTALL) @@ -212,7 +214,68 @@ def run_statistical_analysis(all_seeds_results): return statistical_results def index_nested(encoder_tokenizer_out, batch): - return {k: index_nested(v, batch) if isinstance(v, dict) else v[batch:batch+1] for k, v in encoder_tokenizer_out.items()} + return {k: index_nested(v, batch) if isinstance(v, dict) else v[batch] for k, v in encoder_tokenizer_out.items()} + + +def stack_nested(items): + return {k: stack_nested([item[k] for item in items]) if isinstance(items[0][k], dict) + else torch.stack([item[k] for item in items], dim=0) for k in items[0]} + + +def flatten_eval_turns(dataloader, args, needs_signal_injection): + """Expand every (sample, turn) pair into a flat generation work item.""" + dataset = dataloader.dataset + turns = [] + progress = tqdm(dataloader, desc="Flattening turns", disable=not is_main(), leave=False) + for batch_idx, batch in enumerate(progress): + B = batch["elm_input_ids"].shape[0] + for b in range(B): + full_ids = batch["elm_input_ids"][b].tolist() + full_attn = batch["elm_attention_mask"][b].tolist() + if needs_signal_injection: + signal_indices = batch["signal_id_indices"][b] + encoder_out = index_nested(batch["encoder_tokenizer_out"], b) + ranges = dataset.get_response_ranges(full_ids) + gt_texts = dataset.get_ground_truth_responses(full_ids, ranges) + if getattr(args, "dev", False): + print(f"\n--- Batch {batch_idx}, Sample {b} ---") + print(f"Total turns: {len(ranges)}") + dataset.assert_range_alignment(full_ids, ranges) + for (s, _), gt in zip(ranges, gt_texts): + turn = {"order": len(turns), "prefix_ids": full_ids[:s], + "prefix_attn": full_attn[:s], "gt_text": gt} + if needs_signal_injection: + masked_indices = signal_indices.clone() + masked_indices[masked_indices >= s] = -1 + turn["signal_id_indices"] = masked_indices + turn["encoder_tokenizer_out"] = encoder_out + turns.append(turn) + if train_dev_break(getattr(args, "dev", False), batch, 0): + break + return turns + + +def collate_turns(chunk, pad_token_id): + """Left-pad a chunk of turn items to its max prefix length for generation.""" + max_len = max(len(turn["prefix_ids"]) for turn in chunk) + input_ids, attention_mask = [], [] + for turn in chunk: + pad = max_len - len(turn["prefix_ids"]) + input_ids.append([pad_token_id] * pad + turn["prefix_ids"]) + attention_mask.append([0] * pad + turn["prefix_attn"]) + gen_batch = { + "elm_input_ids": torch.tensor(input_ids, dtype=torch.int64), + "elm_attention_mask": torch.tensor(attention_mask, dtype=torch.float32), + } + if "signal_id_indices" in chunk[0]: + shifted = [] + for turn in chunk: + pad = max_len - len(turn["prefix_ids"]) + indices = turn["signal_id_indices"] + shifted.append(torch.where(indices >= 0, indices + pad, indices)) + gen_batch["signal_id_indices"] = torch.stack(shifted, dim=0) + gen_batch["encoder_tokenizer_out"] = stack_nested([turn["encoder_tokenizer_out"] for turn in chunk]) + return gen_batch def pretrain_diagnostic_breakdown(refs, hyps): split = lambda s: {x.strip() for x in (s or "").split(";") if x.strip()} @@ -299,77 +362,64 @@ def save_incorrect_predictions_histogram_png(references, hypotheses, path, top_k print(f"Saved incorrect-predictions histogram to {path}") def evaluate(elm, dataloader, args): - show_progress = is_main() elm.eval() needs_signal_injection = args.elm in ("mlp_llava", "linear_llava", "base_elf", "patch_elf", "conv_elf") - progress = tqdm( - dataloader, - desc=f"LLM: {args.llm} ENCODER: {args.encoder}", - disable=not show_progress, - leave=False, - ) dataset = dataloader.dataset device = next(elm.parameters()).device - all_refs, all_hyps, all_prompts = [], [], [] + + turns = flatten_eval_turns(dataloader, args, needs_signal_injection) + # Every rank flattens the full dataset identically; each generates a + # disjoint stride of the turn list and results are gathered afterwards. + world_size = get_world_size() + local_turns = turns[get_rank()::world_size] if world_size > 1 else turns + gen_model = elm.module if hasattr(elm, "module") else elm # unwrap DDP for generate() + + eval_batch_size = getattr(args, "eval_batch_size", 1) + pad_token_id = dataset.llm_tokenizer.pad_token_id + results = [] # (order, gt_text, gen_txt, prefix_ids) + progress = tqdm(range(0, len(local_turns), eval_batch_size), + desc=f"LLM: {args.llm} ENCODER: {args.encoder} (eval_bs={eval_batch_size})", + disable=not is_main(), leave=False) with torch.no_grad(): - for batch_idx, batch in enumerate(progress): - B = batch["elm_input_ids"].shape[0] - for b in range(B): - full_ids = batch["elm_input_ids"][b].tolist() - full_attn = batch["elm_attention_mask"][b].tolist() - if needs_signal_injection: - signal_indices = batch["signal_id_indices"][b] - full_encoder_tokenizer_out = index_nested(batch["encoder_tokenizer_out"], b) - ranges = dataset.get_response_ranges(full_ids) - gt_texts = dataset.get_ground_truth_responses(full_ids, ranges) - if getattr(args, "dev", False): - print(f"\n--- Batch {batch_idx}, Sample {b} ---") - print(f"Total turns: {len(ranges)}") - dataset.assert_range_alignment(full_ids, ranges) - for turn_idx, ((s, _), gt) in enumerate(zip(ranges, gt_texts)): - sub_ids = full_ids[:s] - sub_attn = full_attn[:s] - gen_batch = { - "elm_input_ids": torch.tensor(sub_ids, dtype=torch.int64).unsqueeze(0), - "elm_attention_mask": torch.tensor(sub_attn, dtype=torch.float32).unsqueeze(0), - "max_new_tokens": args.max_new_tokens - } - if needs_signal_injection: - gen_batch["encoder_tokenizer_out"] = full_encoder_tokenizer_out - truncated_len = len(sub_ids) - masked_indices = signal_indices.clone() - masked_indices[masked_indices >= truncated_len] = -1 - gen_batch["signal_id_indices"] = masked_indices - gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} - gen_out = elm.generate(**gen_batch)[0].cpu().tolist() - gen_txt = dataset.get_generated_response_for_turn(sub_ids, gen_out) - if getattr(args, "dev", False): - print(f"\nTurn {turn_idx + 1}:") - print(f"\nGround Truth:\n{gt}") - print(f"\nGenerated:\n{gen_txt}") - print("-" * 100) - if gt and gen_txt: - all_prompts.append(dataset.llm_tokenizer.decode(sub_ids, skip_special_tokens=True).strip()) - all_refs.append(gt) - all_hyps.append(gen_txt) - if train_dev_break(getattr(args, "dev", False), batch, 0): - break - # if batch_idx == 10: - # break - # input() + for start in progress: + chunk = local_turns[start:start + eval_batch_size] + gen_batch = collate_turns(chunk, pad_token_id) + gen_batch = {k: batch_to_device(v, device) for k, v in gen_batch.items()} + gen_out = gen_model.generate(**gen_batch, max_new_tokens=args.max_new_tokens) + for turn, row in zip(chunk, gen_out): + gen_txt = dataset.get_generated_response_for_turn(turn["prefix_ids"], row.cpu().tolist()) + if getattr(args, "dev", False) and is_main(): + print(f"\nTurn (order {turn['order']}):") + print(f"\nGround Truth:\n{turn['gt_text']}") + print(f"\nGenerated:\n{gen_txt}") + print("-" * 100) + results.append((turn["order"], turn["gt_text"], gen_txt, turn["prefix_ids"])) + + if world_size > 1: + gathered = [None] * world_size + dist.all_gather_object(gathered, results) + results = [r for shard in gathered for r in shard] + results.sort(key=lambda r: r[0]) + all_refs, all_hyps, all_prompts = [], [], [] + for _, gt, gen_txt, prefix_ids in results: + if gt and gen_txt: + all_prompts.append(dataset.llm_tokenizer.decode(prefix_ids, skip_special_tokens=True).strip()) + all_refs.append(gt) + all_hyps.append(gen_txt) refs_t, refs_a = map(list, zip(*map(split_response, all_refs))) if all_refs else ([], []) hyps_t, hyps_a = map(list, zip(*map(split_response, all_hyps))) if all_hyps else ([], []) think_pairs = [(r, h) for r, h in zip(refs_t, hyps_t) if r and h] results = {"answer": evaluate_strings(refs_a, hyps_a)} if think_pairs: results["thinking"] = evaluate_strings(*map(list, zip(*think_pairs))) - print("\n=== N-Turn Evaluation (generated vs. gold response only) ===") - print(f"Pairs: {len(all_refs)} (thinking pairs: {len(think_pairs)})") - for group, mdict in results.items(): - print(f"[{group}]") - for k, v in mdict.items(): - print(f" {k}: {v:.4f}") + if is_main(): + print("\n=== N-Turn Evaluation (generated vs. gold response only) ===") + print(f"Pairs: {len(all_refs)} (thinking pairs: {len(think_pairs)})") + for group, mdict in results.items(): + print(f"[{group}]") + for k, v in mdict.items(): + print(f" {k}: {v:.4f}") out = { "num_pairs": len(all_refs), "metrics": results, @@ -380,13 +430,15 @@ def evaluate(elm, dataloader, args): if getattr(args, "train_phase", "sft") == "pretrain" and refs_a: breakdown = pretrain_diagnostic_breakdown(refs_a, hyps_a) out["pretrain_breakdown"] = breakdown - print(f"\n=== Pretrain diagnostic breakdown (N={breakdown['n']:,}) ===") - print(f" Matched={breakdown['matched']:,} Not matched={breakdown['not_matched']:,} Other={breakdown['other']:,}") - print(f" Missed_inst={breakdown['missed_inst']:,} Extra_inst={breakdown['extra_inst']:,}") - print(f" Only missed={breakdown['only_missed']:,} Only extra={breakdown['only_extra']:,} Both={breakdown['both']:,}") + if is_main(): + print(f"\n=== Pretrain diagnostic breakdown (N={breakdown['n']:,}) ===") + print(f" Matched={breakdown['matched']:,} Not matched={breakdown['not_matched']:,} Other={breakdown['other']:,}") + print(f" Missed_inst={breakdown['missed_inst']:,} Extra_inst={breakdown['extra_inst']:,}") + print(f" Only missed={breakdown['only_missed']:,} Only extra={breakdown['only_extra']:,} Both={breakdown['both']:,}") if any(d.startswith("ecg-comp") for d in args.data): per_class_acc, confusion_matrix, other_counts = compute_classification_metrics(refs_a, hyps_a) - print_classification_metrics(per_class_acc, confusion_matrix) + if is_main(): + print_classification_metrics(per_class_acc, confusion_matrix) results["per_class_acc"] = per_class_acc out["confusion_matrix"] = confusion_matrix out["other_output_counts"] = other_counts