From c21a3ba9c4a9272012ac5ecf563d62a9569aab99 Mon Sep 17 00:00:00 2001
From: David Salinas
Date: Fri, 17 Apr 2026 17:14:23 +0200
Subject: [PATCH 1/3] message in case of stale datasets (#59)

* message in case of stale datasets

* fix ruff
---
 oellm/utils.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/oellm/utils.py b/oellm/utils.py
index faf1bd1..4b6830c 100644
--- a/oellm/utils.py
+++ b/oellm/utils.py
@@ -343,6 +343,18 @@ def _pre_download_datasets_from_specs(
                     trust_remote_code=trust_remote_code,
                 )
                 continue
+            if "Feature type" in str(e) and "not found" in str(e):
+                hf_datasets_cache = os.environ.get(
+                    "HF_DATASETS_CACHE",
+                    str(Path.home() / ".cache" / "huggingface" / "datasets"),
+                )
+                safe_name = spec.repo_id.replace("/", "___")
+                cache_dir = os.path.join(hf_datasets_cache, safe_name)
+                raise RuntimeError(
+                    f"Cached metadata for '{label}' is incompatible with the installed "
+                    f"datasets version ('{e}'). Delete the stale cache and re-run:\n\n"
+                    f"  rm -rf {cache_dir}\n"
+                ) from None
             raise
 
         logging.debug(f"Finished downloading dataset '{label}'.")

From cd465d40fac99540f0cfaf8bc982f154f3bcc0ea Mon Sep 17 00:00:00 2001
From: Ting-Wen Ko <55427790+kerkathy@users.noreply.github.com>
Date: Tue, 21 Apr 2026 11:04:33 +0200
Subject: [PATCH 2/3] Enable belebele_cf on Leonardo + SLURM memory request +
 lighteval batch-size defaults + enable `collect-results` for belebele_cf
 (#60)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add belebele_cf

* Add SLURM memory request, plus lighteval batch-size control and defaults (1 for local, 32 for cluster)

* In oellm-collect, add support for searching nested JSON results, and fix n_shot extraction

* fix lint errors

* Fix metric for belebele_cf to use acc_norm

* Clean up slurm mem request logic

* Clean [debug] from output

* Prettify GPU and lighteval output messages; rename the variables passed into lighteval's batch_size and model_args

* Fix formatting

* Fix incompatibility bug on non-GPU machines

---------

Co-authored-by: David Salinas
---
 README.md                        |  32 +++++++-
 oellm/main.py                    | 122 +++++++++++++++++++++++------
 oellm/resources/task-groups.yaml |  85 +++++++++++++++++++++
 oellm/resources/template.sbatch  |  10 ++-
 oellm/task_groups.py             |  18 +++++
 5 files changed, 239 insertions(+), 28 deletions(-)

diff --git a/README.md b/README.md
index a92f932..dd9d43e 100644
--- a/README.md
+++ b/README.md
@@ -99,19 +99,43 @@ The `HF_HUB_OFFLINE` value is read when you invoke `oellm` and baked into the ge
 
 ## SLURM Overrides
 
-Override cluster defaults (partition, account, time limit, etc.) with `--slurm_template_var` (JSON object):
+Override cluster defaults (partition, account, time limit, memory, etc.) with `--slurm_template_var` (JSON object). Provide `SLURM_MEM` to request an exact host memory amount; otherwise the job falls back to a default of `96G`.
 
 ```bash
 # Use a different partition (e.g. 
dev-g on LUMI when small-g is crowded)
 oellm schedule-eval --models "model-name" --task_groups "open-sci-0.01" \
   --slurm_template_var '{"PARTITION":"dev-g"}'
 
-# Multiple overrides: partition, account, time limit, GPUs
+# Multiple overrides: partition, account, time limit, GPUs, exact RAM
 oellm schedule-eval --models "model-name" --task_groups "open-sci-0.01" \
-  --slurm_template_var '{"PARTITION":"dev-g","ACCOUNT":"myproject","TIME":"02:00:00","GPUS_PER_NODE":2}'
+  --slurm_template_var '{"PARTITION":"dev-g","ACCOUNT":"myproject","TIME":"02:00:00","GPUS_PER_NODE":2,"SLURM_MEM":"96G"}'
 ```
 
-Use exact env var names: `PARTITION`, `ACCOUNT`, `GPUS_PER_NODE`. `TIME` (HH:MM:SS) overrides the time limit.
+Use exact env var names: `PARTITION`, `ACCOUNT`, `GPUS_PER_NODE`, `SLURM_MEM`. `TIME` (HH:MM:SS) overrides the time limit.
+
+## Lighteval Batch Size
+
+For lighteval runs, generated jobs default to `batch_size=1` for local runs and
+`batch_size=32` for non-local (SLURM/cluster) runs. This reduces the risk of
+out-of-memory failures from lighteval's auto batch-size detection, which can
+be overly optimistic for multiple-choice loglikelihood tasks. You can still
+override these defaults:
+
+```bash
+# Set an explicit batch size (overrides the local/cluster default)
+BATCH_SIZE=8 oellm schedule-eval \
+  --models "model-name" \
+  --task_groups "belebele-eu-cf" \
+  --venv_path .venv
+```
+
+If you need full manual control over all model args, set `MODEL_ARGS`,
+for example:
+
+```bash
+MODEL_ARGS='batch_size=8' oellm schedule-eval \
+  --models "model-name" --task_groups "belebele-eu-cf" --venv_path .venv
+```
 
 ## ⚠️ Dataset Pre-Download Warning
 
diff --git a/oellm/main.py b/oellm/main.py
index c32baec..6a19984 100644
--- a/oellm/main.py
+++ b/oellm/main.py
@@ -14,6 +14,7 @@ from jsonargparse import auto_cli
 
 from oellm.task_groups import (
+    _build_task_suite_map,
     _collect_dataset_specs,
     _expand_task_groups,
     _lookup_dataset_specs_for_tasks,
@@ -47,6 +48,50 @@ def _resolve_hf_hub_offline(local: bool) -> int:
     return 0 if local else 1
 
 
+def _resolve_slurm_mem() -> str:
+    """Return the host-memory request for the generated SLURM job."""
+    explicit_mem = os.environ.get("SLURM_MEM")
+    if explicit_mem is not None and str(explicit_mem).strip() != "":
+        return str(explicit_mem).strip()
+
+    logging.warning("SLURM_MEM not set; falling back to default memory request '96G'.")
+    return "96G"
+
+
+def _resolve_additional_model_args(local: bool = False) -> str:
+    """Return model args for lighteval, defaulting to an explicit batch size.
+    - if `local` is True: `batch_size=1`
+    - otherwise: `batch_size=32`
+
+    Users may override the entire model args via `MODEL_ARGS` or the
+    batch size via `BATCH_SIZE` environment variables.
+
+    For now this is only passed to the lighteval suite, not yet to evalchemy or lm-eval.
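+
+    Illustration (hypothetical values): with `BATCH_SIZE=8` set, this returns
+    `batch_size=8`; if `MODEL_ARGS` is also set, it takes precedence and is
+    returned verbatim.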
+ """ + explicit_model_args = os.environ.get("MODEL_ARGS") + if explicit_model_args is not None and str(explicit_model_args).strip() != "": + return str(explicit_model_args).strip() + + batch_size = os.environ.get("BATCH_SIZE") + if batch_size is not None and str(batch_size).strip() != "": + batch_size_value = str(batch_size).strip() + try: + if int(batch_size_value) < 1: + raise ValueError + except ValueError: + fallback = "1" if local else "32" + logging.warning( + "Invalid BATCH_SIZE=%r; falling back to batch_size=%s", + batch_size, + fallback, + ) + batch_size_value = fallback + else: + batch_size_value = "1" if local else "32" + + return f"batch_size={batch_size_value}" + + @dataclass class EvaluationJob: model_path: Path | str @@ -113,8 +158,8 @@ def schedule_evals( submitting to SLURM. Requires --venv_path. Skips cluster environment detection and runs all evaluations sequentially in a single process. slurm_template_var: JSON object of template variable overrides. Use exact env var names - (PARTITION, ACCOUNT, GPUS_PER_NODE). "TIME" overrides the time limit. - Example: '{"PARTITION":"dev-g","ACCOUNT":"FOO","TIME":"02:00:00","GPUS_PER_NODE":2}' + (PARTITION, ACCOUNT, GPUS_PER_NODE, SLURM_MEM). "TIME" overrides the time limit. + Example: '{"PARTITION":"dev-g","ACCOUNT":"FOO","TIME":"02:00:00","GPUS_PER_NODE":2,"SLURM_MEM":"96G"}' """ _setup_logging(verbose) @@ -191,13 +236,14 @@ def schedule_evals( elif models: if task_groups is None: + task_suite_map = _build_task_suite_map() eval_jobs.extend( [ EvaluationJob( model_path=model, task_path=task, n_shot=shot, - eval_suite="lm_eval", + eval_suite=task_suite_map.get(task, "lm_eval"), ) for model in models for task in tasks @@ -356,6 +402,7 @@ def schedule_evals( logging.info(f"Using slurm_template_var override: {key}={value}") # Log the calculated values + slurm_mem = _resolve_slurm_mem() logging.info("📊 Evaluation planning:") logging.info(f" Total evaluations: {total_evals}") logging.info(f" Estimated time per eval: {minutes_per_eval} minutes") @@ -371,6 +418,7 @@ def schedule_evals( f" Time per job: {minutes_per_job} minutes ({minutes_per_job / 60:.1f} hours)" ) logging.info(f" Time limit with safety margin: {time_limit}") + logging.info(f" Requested host memory: {slurm_mem}") sbatch_script = sbatch_template.format( csv_path=csv_path, @@ -381,14 +429,13 @@ def schedule_evals( log_dir=evals_dir / "slurm_logs", evals_dir=str(evals_dir / "results"), time_limit=time_limit, # Dynamic time limit + slurm_mem=slurm_mem, limit=limit if limit else "", # Sample limit for quick testing venv_path=venv_path or "", lm_eval_include_path=lm_eval_include_path or str(files("oellm.resources") / "custom_lm_eval_tasks"), hf_hub_offline=_resolve_hf_hub_offline(local), - lighteval_model_args="trust_remote_code=True,batch_size=1" - if local - else "trust_remote_code=True", + additional_model_args=_resolve_additional_model_args(local), # Batch size evalchemy_dir=os.environ.get("EVALCHEMY_DIR", "/opt/evalchemy"), ) @@ -523,7 +570,7 @@ def _first_matching_prefix( val, key = _first_matching_prefix(result_dict, preferred) return val, key - for metric in ["acc,none", "acc", "accuracy", "f1", "exact_match"]: + for metric in ["acc,none", "acc", "accuracy", "acc_norm", "f1", "exact_match"]: val, key = _first_numeric(result_dict, metric) if val is not None: return val, key @@ -532,14 +579,24 @@ def _first_matching_prefix( return val, key return None, None + def _split_task_and_nshot(name: str) -> tuple[str, int | None]: + """Split task names of the form 'task|N' returning 
(task, N) or (task, None).""" + if not isinstance(name, str): + return name, None + if "|" in name: + base, after = name.rsplit("|", 1) + if after.isdigit(): + return base, int(after) + return name, None + results_path = Path(results_dir) if not results_path.exists(): raise ValueError(f"Results directory does not exist: {results_dir}") # Check if we need to look in a 'results' subdirectory if (results_path / "results").exists() and (results_path / "results").is_dir(): - # User passed the top-level directory, look in results subdirectory - json_files = list((results_path / "results").glob("*.json")) + # User passed the top-level directory, look in results subdirectory for nested json files + json_files = list((results_path / "results").rglob("*.json")) else: # User passed the results directory directly json_files = list(results_path.glob("*.json")) @@ -569,8 +626,17 @@ def _first_matching_prefix( with open(json_file) as f: data = json.load(f) - # Extract model name/path - model_name = data.get("model_name", "unknown") + # Extract model name/path from a few common locations used in different + # versions of the result JSON schema. + model_name = ( + data.get("model_name") + or data.get("config_general", {}).get("model_name") + or data.get("config_general", {}).get("model") + or data.get("config_general", {}).get("model_path") + or data.get("summary_general", {}).get("model") + or data.get("model") + or "unknown" + ) # Extract results for each task results = data.get("results", {}) @@ -603,14 +669,21 @@ def _first_matching_prefix( # Prefer only the first aggregate metric from groups (simplified) if groups_map: group_name, group_results = next(iter(groups_map.items())) - n_shot = n_shot_data.get(group_name, "unknown") + # Prefer original extraction from n_shot_data and subtasks, then + # global_n_shot; only fall back to parsing the group name. + orig_group_name = group_name + n_shot = n_shot_data.get(orig_group_name, "unknown") if n_shot == "unknown": - for subtask_name in group_subtasks_map.get(group_name, []): + for subtask_name in group_subtasks_map.get(orig_group_name, []): if subtask_name in n_shot_data: n_shot = n_shot_data[subtask_name] break if n_shot == "unknown" and global_n_shot is not None: n_shot = global_n_shot + # Fallback: parse possible '|N' suffix from group name + group_name, parsed_n = _split_task_and_nshot(orig_group_name) + if n_shot == "unknown" and parsed_n is not None: + n_shot = parsed_n performance, metric_name = _resolve_metric(group_name, group_results) if performance is not None: if check: @@ -643,12 +716,17 @@ def _first_matching_prefix( if task_name.startswith("global_mmlu_") and task_name.count("_") >= 4: continue - # Get n_shot for this task - n_shot = n_shot_data.get(task_name, "unknown") + # Get n_shot for this task. + # Prefer original extraction from `n_shot_data` and `global_n_shot`, + # and fall back to parsing a '|N' suffix in the task name. 
+        task_name_clean, parsed_n = _split_task_and_nshot(task_name)
+        fallback = global_n_shot if global_n_shot is not None else parsed_n
+        n_shot = n_shot_data.get(task_name_clean, fallback)  # preserves n_shot == 0
+        n_shot = "unknown" if n_shot is None else n_shot
 
         # If this is a group aggregate and n_shot is missing, derive from any subtask
-        if task_name in group_aggregate_names and n_shot == "unknown":
-            for subtask_name in group_subtasks_map.get(task_name, []):
+        if task_name_clean in group_aggregate_names and n_shot == "unknown":
+            for subtask_name in group_subtasks_map.get(task_name_clean, []):
                 if subtask_name in n_shot_data:
                     n_shot = n_shot_data[subtask_name]
                     break
@@ -656,7 +734,7 @@ def _first_matching_prefix(
                 n_shot = global_n_shot
 
         # Special handling for MMLU aggregate - get n_shot from any MMLU subtask
-        if task_name == "mmlu" and n_shot == "unknown":
+        if task_name_clean == "mmlu" and n_shot == "unknown":
             for key, value in n_shot_data.items():
                 if key.startswith("mmlu_"):
                     n_shot = value
@@ -665,8 +743,8 @@ def _first_matching_prefix(
                 n_shot = global_n_shot
 
         # Special handling for Global MMLU aggregates - get n_shot from subtasks
-        if task_name.startswith("global_mmlu_") and n_shot == "unknown":
-            prefix = f"{task_name}_"
+        if task_name_clean.startswith("global_mmlu_") and n_shot == "unknown":
+            prefix = f"{task_name_clean}_"
             for key, value in n_shot_data.items():
                 if key.startswith(prefix):
                     n_shot = value
@@ -680,12 +758,12 @@ def _first_matching_prefix(
         if performance is not None:
             # Track completed job for check mode
             if check:
-                completed_jobs.add((model_name, task_name, n_shot))
+                completed_jobs.add((model_name, task_name_clean, n_shot))
 
             rows.append(
                 {
                     "model_name": model_name,
-                    "task": task_name,
+                    "task": task_name_clean,
                     "n_shot": n_shot,
                     "performance": performance,
                     "metric_name": metric_name if metric_name is not None else "",
diff --git a/oellm/resources/task-groups.yaml b/oellm/resources/task-groups.yaml
index 4d40ef4..0aa5d12 100644
--- a/oellm/resources/task-groups.yaml
+++ b/oellm/resources/task-groups.yaml
@@ -1,5 +1,31 @@
 task_metrics:
   mmlu: acc
+  belebele_bul_Cyrl_cf: acc_norm
+  belebele_hrv_Latn_cf: acc_norm
+  belebele_ces_Latn_cf: acc_norm
+  belebele_dan_Latn_cf: acc_norm
+  belebele_nld_Latn_cf: acc_norm
+  belebele_eng_Latn_cf: acc_norm
+  belebele_est_Latn_cf: acc_norm
+  belebele_fin_Latn_cf: acc_norm
+  belebele_fra_Latn_cf: acc_norm
+  belebele_deu_Latn_cf: acc_norm
+  belebele_ell_Grek_cf: acc_norm
+  belebele_hun_Latn_cf: acc_norm
+  belebele_ita_Latn_cf: acc_norm
+  belebele_lvs_Latn_cf: acc_norm
+  belebele_lit_Latn_cf: acc_norm
+  belebele_mlt_Latn_cf: acc_norm
+  belebele_pol_Latn_cf: acc_norm
+  belebele_por_Latn_cf: acc_norm
+  belebele_ron_Latn_cf: acc_norm
+  belebele_slk_Latn_cf: acc_norm
+  belebele_slv_Latn_cf: acc_norm
+  belebele_spa_Latn_cf: acc_norm
+  belebele_swe_Latn_cf: acc_norm
+  belebele_nob_Latn_cf: acc_norm
+  belebele_eus_Latn_cf: acc_norm
+  belebele_cat_Latn_cf: acc_norm
   copa: acc
   lambada_openai: acc
   openbookqa: acc_norm
@@ -122,6 +148,65 @@ task_groups:
       subset: swe_Latn
     - task: belebele_nob_Latn
       subset: nob_Latn
+  belebele-eu-cf:
+    description: "Belebele European language tasks (cloze formulation, lighteval)"
+    suite: lighteval
+    n_shots: [0]
+    dataset: facebook/belebele
+    tasks:
+    - task: belebele_bul_Cyrl_cf
+      subset: bul_Cyrl
+    - task: belebele_hrv_Latn_cf
+      subset: hrv_Latn
+    - task: belebele_ces_Latn_cf
+      subset: ces_Latn
+    - task: belebele_dan_Latn_cf
+      subset: dan_Latn
+    - task: belebele_nld_Latn_cf
+      subset: nld_Latn
+    - task: belebele_eng_Latn_cf
+      subset: eng_Latn
+    - task: belebele_est_Latn_cf
+      subset: est_Latn
+    - task: belebele_fin_Latn_cf
+      
subset: fin_Latn
+    - task: belebele_fra_Latn_cf
+      subset: fra_Latn
+    - task: belebele_deu_Latn_cf
+      subset: deu_Latn
+    - task: belebele_ell_Grek_cf
+      subset: ell_Grek
+    - task: belebele_hun_Latn_cf
+      subset: hun_Latn
+    - task: belebele_ita_Latn_cf
+      subset: ita_Latn
+    - task: belebele_lvs_Latn_cf
+      subset: lvs_Latn
+    - task: belebele_lit_Latn_cf
+      subset: lit_Latn
+    - task: belebele_mlt_Latn_cf
+      subset: mlt_Latn
+    - task: belebele_pol_Latn_cf
+      subset: pol_Latn
+    - task: belebele_por_Latn_cf
+      subset: por_Latn
+    - task: belebele_ron_Latn_cf
+      subset: ron_Latn
+    - task: belebele_slk_Latn_cf
+      subset: slk_Latn
+    - task: belebele_slv_Latn_cf
+      subset: slv_Latn
+    - task: belebele_spa_Latn_cf
+      subset: spa_Latn
+    - task: belebele_swe_Latn_cf
+      subset: swe_Latn
+    - task: belebele_nob_Latn_cf
+      subset: nob_Latn
+    - task: belebele_eus_Latn_cf
+      subset: eus_Latn
+    - task: belebele_cat_Latn_cf
+      subset: cat_Latn
+
   flores-200-eu-to-eng:
     description: "Flores 200 EU to English translation"
     suite: lighteval
diff --git a/oellm/resources/template.sbatch b/oellm/resources/template.sbatch
index 53995b4..38ecfa2 100644
--- a/oellm/resources/template.sbatch
+++ b/oellm/resources/template.sbatch
@@ -2,6 +2,7 @@
 #SBATCH --job-name=oellm-eval
 #SBATCH --time={time_limit}
 #SBATCH --gres=gpu:$GPUS_PER_NODE
+#SBATCH --mem={slurm_mem}
 #SBATCH --output={log_dir}/%x-%A-%a.out
 #SBATCH --partition=$PARTITION
 #SBATCH --account=$ACCOUNT
@@ -129,6 +130,9 @@ do
 
     case "$suite_normalized" in
         lm_eval|lm-eval|lm-eval-harness)
+            echo
+            echo "----------------------------------------------------"
+            echo "lm_eval Execution"
             run_python -m lm_eval --model hf \
                 --model_args pretrained="$model_path",trust_remote_code=True \
                 --tasks "$task_path" \
@@ -137,6 +141,7 @@ do
                 --trust_remote_code \
                 ${{LM_EVAL_INCLUDE_PATH:+--include_path $LM_EVAL_INCLUDE_PATH}} \
                 ${{LIMIT:+--limit $LIMIT}}
+            echo "----------------------------------------------------"
             ;;
         lighteval|light-eval)
             LIGHT_TASK="$task_path"
@@ -162,7 +167,7 @@ do
             if [ -n "$VENV_PATH" ]; then
                 source "$VENV_PATH/bin/activate"
                 lighteval accelerate \
-                    "model_name=$model_path,{lighteval_model_args}" \
+                    "model_name=$model_path,trust_remote_code=True,{additional_model_args}" \
                     "$LIGHT_TASK_ARG" \
                     --load-tasks-multilingual \
                     --output-dir "$RESULTS_SUBDIR" \
@@ -175,7 +180,7 @@ do
                 $EVAL_SIF_PATH \
                 env CUDA_VISIBLE_DEVICES=$GPU_DEVICES \
                 lighteval accelerate \
-                    "model_name=$model_path,{lighteval_model_args}" \
+                    "model_name=$model_path,trust_remote_code=True,{additional_model_args}" \
                     "$LIGHT_TASK_ARG" \
                     --load-tasks-multilingual \
                     --output-dir "$RESULTS_SUBDIR" \
@@ -206,6 +211,7 @@ do
             ;;
     esac
 
+    echo "----------------------------------------------------"
     echo "Evaluation finished for model: $model_path"
 done
 
diff --git a/oellm/task_groups.py b/oellm/task_groups.py
index 7293013..2ab1dc8 100644
--- a/oellm/task_groups.py
+++ b/oellm/task_groups.py
@@ -266,6 +266,24 @@ def _lookup_dataset_specs_for_tasks(task_names: Iterable[str]) -> list[DatasetSp
     return specs
 
 
+def _build_task_suite_map() -> dict[str, str]:
+    """Build a mapping from task names to their suite from all task groups."""
+    data = (
+        yaml.safe_load((files("oellm.resources") / "task-groups.yaml").read_text()) or {}
+    )
+
+    task_suite_map: dict[str, str] = {}
+    for _, group_data in data.get("task_groups", {}).items():
+        group_suite = group_data.get("suite", "lm-eval-harness")
+        for task_data in group_data.get("tasks", []):
+            task_name = task_data.get("task")
+            task_suite = task_data.get("suite", group_suite)
+            if task_name and task_name not in 
task_suite_map: + task_suite_map[task_name] = task_suite + + return task_suite_map + + def get_all_task_group_names() -> list[str]: """Return all available task group names (excluding super_groups).""" data = ( From fdfdc4962e2773d61c2d109ef93dde2a07dbb760 Mon Sep 17 00:00:00 2001 From: Ivan Slobozhan Date: Wed, 22 Apr 2026 16:49:12 +0200 Subject: [PATCH 3/3] clean up --- oellm/scheduler.py | 8 +++- oellm/task_groups.py | 24 +++++----- tests/test_task_suite_map.py | 88 ++++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 13 deletions(-) create mode 100644 tests/test_task_suite_map.py diff --git a/oellm/scheduler.py b/oellm/scheduler.py index 979fb71..af203c9 100644 --- a/oellm/scheduler.py +++ b/oellm/scheduler.py @@ -14,6 +14,7 @@ from oellm.constants import EvaluationJob from oellm.runner import EvalRunner from oellm.task_groups import ( + _build_task_suite_map, _collect_dataset_specs, _collect_hf_dataset_files, _collect_hf_model_repos, @@ -233,13 +234,18 @@ def schedule_evals( elif models: if group_names is None: + # Look up each bare task name in the registered groups so + # ``--tasks belebele_eng_Latn_cf`` (lighteval) or ``--tasks + # regiondial_refcocog_all`` (contrib) get routed correctly. + # Tasks not in any group default to lm_eval. + task_suite_map = _build_task_suite_map() eval_jobs.extend( [ EvaluationJob( model_path=model, task_path=task, n_shot=shot, - eval_suite="lm_eval", + eval_suite=task_suite_map.get(task, "lm_eval"), ) for model in models for task in tasks diff --git a/oellm/task_groups.py b/oellm/task_groups.py index 5793a19..d930532 100644 --- a/oellm/task_groups.py +++ b/oellm/task_groups.py @@ -324,20 +324,20 @@ def _lookup_dataset_specs_for_tasks(task_names: Iterable[str]) -> list[DatasetSp def _build_task_suite_map() -> dict[str, str]: - """Build a mapping from task names to their suite from all task groups.""" - data = ( - yaml.safe_load((files("oellm.resources") / "task-groups.yaml").read_text()) or {} - ) + """Return ``{task_name: eval_suite}`` across core YAML and contrib plugins. - task_suite_map: dict[str, str] = {} - for _, group_data in data.get("task_groups", {}).items(): - group_suite = group_data.get("suite", "lm-eval-harness") - for task_data in group_data.get("tasks", []): - task_name = task_data.get("task") - task_suite = task_data.get("suite", group_suite) - if task_name and task_name not in task_suite_map: - task_suite_map[task_name] = task_suite + Uses :func:`_parse_task_groups` + :func:`_iter_all_tasks` so contrib + registries (e.g. ``regiondial_bench``) are included, not just the core + ``task-groups.yaml``. Task-level ``suite`` overrides group-level. First + occurrence wins when a task name appears in multiple groups. + Consumers should still ``.get(task, "lm_eval")`` — tasks not registered + in any group simply aren't in the map. + """ + parsed = _parse_task_groups(get_all_task_group_names()) + task_suite_map: dict[str, str] = {} + for t, suite, _group in _iter_all_tasks(parsed): + task_suite_map.setdefault(t.name, suite) return task_suite_map diff --git a/tests/test_task_suite_map.py b/tests/test_task_suite_map.py new file mode 100644 index 0000000..c8c6be1 --- /dev/null +++ b/tests/test_task_suite_map.py @@ -0,0 +1,88 @@ +"""Tests for :func:`oellm.task_groups._build_task_suite_map`. + +The helper powers the ``--tasks`` (bare-task-name) path in the scheduler. +It must cover every suite we actually support — core YAML-registered suites +(lm-eval-harness, lighteval, lmms_eval, evalchemy) AND contrib-registered +suites (e.g. 
regiondial_bench).
+"""
+
+from __future__ import annotations
+
+from oellm.task_groups import _build_task_suite_map
+
+
+def test_map_is_non_empty():
+    m = _build_task_suite_map()
+    assert len(m) > 0, "suite map must contain at least core YAML tasks"
+
+
+def test_map_includes_lm_eval_harness_task():
+    m = _build_task_suite_map()
+    # copa is a classic lm-eval-harness task in task-groups.yaml
+    assert m.get("copa") == "lm-eval-harness"
+
+
+def test_map_includes_lighteval_task():
+    m = _build_task_suite_map()
+    # belebele_*_cf tasks are lighteval
+    assert m.get("belebele_eng_Latn_cf") == "lighteval"
+
+
+def test_map_includes_lmms_eval_task():
+    """lmms_eval tasks come from image/video task groups — must be routable."""
+    m = _build_task_suite_map()
+    # vqav2_val is the base VQA v2 task (image modality)
+    assert m.get("vqav2_val") == "lmms_eval"
+
+
+def test_map_includes_contrib_task():
+    """Contrib plugins (e.g. regiondial_bench) register their own TASK_GROUPS.
+
+    These are the regression target: the original upstream helper only read
+    YAML and missed contrib entirely.
+    """
+    m = _build_task_suite_map()
+    assert m.get("regiondial_refcocog") == "regiondial_bench"
+
+
+def test_map_honours_task_level_suite_override():
+    """Evalchemy tasks set ``suite: evalchemy`` at the task level, not the
+    group level — the helper must prefer the task-level value.
+    """
+    m = _build_task_suite_map()
+    assert m.get("GPQADiamond") == "evalchemy"
+
+
+def test_map_covers_all_actually_registered_suites():
+    """Sanity: every distinct suite we see should be one we actually route.
+
+    Guards against a new suite slipping into YAML or contrib without us
+    adding a case branch in template.sbatch (the ``*)`` catch-all routes
+    everything unknown to the contrib dispatcher, but we still want this
+    assertion as documentation).
+    """
+    m = _build_task_suite_map()
+    distinct_suites = set(m.values())
+    expected_subset = {
+        "lm-eval-harness",
+        "lighteval",
+        "lmms_eval",
+        "evalchemy",
+        "regiondial_bench",
+    }
+    # All expected suites must be present. Extra contrib suites are fine.
+    assert expected_subset.issubset(distinct_suites), (
+        f"missing suites: {expected_subset - distinct_suites}"
+    )
+
+
+def test_map_is_stable_across_calls():
+    """Repeated calls must return identical maps.
+
+    First occurrence wins when a task name appears in multiple groups
+    (documented ``setdefault`` behavior in the helper), but the YAML
+    contents shift, so we assert determinism rather than a specific pair.
+    """
+    m1 = _build_task_suite_map()
+    m2 = _build_task_suite_map()
+    assert m1 == m2
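+
+
+def test_unregistered_task_absent_from_map():
+    """A sketch of the fallback contract, assuming this placeholder task
+    name is never registered by any group: unknown tasks stay out of the
+    map, so callers fall back to lm_eval via ``.get(task, "lm_eval")``.
+    """
+    m = _build_task_suite_map()
+    assert "totally_unregistered_placeholder_task" not in m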