diff --git a/src/eval_framework/response_generator.py b/src/eval_framework/response_generator.py index c8168eed..f8cc2171 100644 --- a/src/eval_framework/response_generator.py +++ b/src/eval_framework/response_generator.py @@ -78,7 +78,7 @@ def __init__(self, llm: BaseLLM, config: EvalConfig, result_processor: ResultsFi custom_hf_revision=self.config.hf_revision, ) - self.response_type, _ = self.task._get_type_and_metrics() + self.response_type = self.task.get_response_type() def _llm_task_param_precedence(self) -> tuple[list[str] | None, int | None]: """ diff --git a/src/eval_framework/tasks/base.py b/src/eval_framework/tasks/base.py index 16fb94b8..07c21909 100644 --- a/src/eval_framework/tasks/base.py +++ b/src/eval_framework/tasks/base.py @@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict from eval_framework.shared.types import BaseMetricContext, Completion, Error, RawCompletion -from eval_framework.tasks.utils import raise_errors +from eval_framework.tasks.utils import classproperty, raise_errors from template_formatting.formatter import Message, Role if TYPE_CHECKING: @@ -91,8 +91,6 @@ class BaseTask[SubjectType](ABC): DATASET_PATH: str SAMPLE_SPLIT: str FEWSHOT_SPLIT: str - RESPONSE_TYPE: ResponseType - METRICS: list[type["BaseMetric"]] SUBJECTS: list[SubjectType] HF_REVISION: str | None = None # tag name, or branch name, or commit hash to ensure reproducibility @@ -104,6 +102,10 @@ class BaseTask[SubjectType](ABC): # language by subtopic, or `None` (for tasks not specific to a single language). LANGUAGE: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None + # RESPONSE_TYPE and METRICS are exposed as classproperties, so you can access them via either + # `TaskClass.*` or `task.*` (or `task.get_metrics()`). This avoids mypy conflicts from re-declaring class vars. + # By default, these values come from TASK_STYLER if set, otherwise from legacy class attributes.
+ def __init__(self, num_fewshot: int = 0) -> None: self.num_fewshot = num_fewshot self.stop_sequences: list[str] | None = None @@ -332,14 +334,12 @@ def _get_context(self, item: dict[str, Any]) -> BaseMetricContext | list[BaseMet return None def get_metadata(self) -> dict[str, str | list[str]]: - response_type, metrics = self._get_type_and_metrics() - meta: dict[str, str | list[str]] = { "dataset_path": self.DATASET_PATH, "sample_split": self.SAMPLE_SPLIT, "fewshot_split": self.FEWSHOT_SPLIT, - "response_type": response_type.value, - "metrics": [m.NAME for m in metrics], + "response_type": self.get_response_type().value, + "metrics": [m.NAME for m in self.get_metrics()], "subjects": [str(s) for s in self.SUBJECTS], } if hasattr(self, "TASK_STYLER"): @@ -420,7 +420,26 @@ def generate_completions( ) return completion_list - def _get_type_and_metrics(self) -> tuple[ResponseType, list[type["BaseMetric"]]]: - if hasattr(self, "TASK_STYLER"): - return self.TASK_STYLER.response_type, self.TASK_STYLER.metrics - return self.RESPONSE_TYPE, self.METRICS + @classmethod + def get_response_type(cls) -> ResponseType: + """Return the response type of the task (or the styler if it exists).""" + if hasattr(cls, "TASK_STYLER"): + return cls.TASK_STYLER.response_type + return cls.RESPONSE_TYPE + + @classmethod + def get_metrics(cls) -> list[type["BaseMetric"]]: + """Return the metrics of the task (or the styler if it exists).""" + if hasattr(cls, "TASK_STYLER"): + return cls.TASK_STYLER.metrics + return cls.METRICS + + @classproperty + def RESPONSE_TYPE(cls) -> ResponseType: + """For backwards compatibility.""" + return cls.get_response_type() + + @classproperty + def METRICS(cls) -> list[type["BaseMetric"]]: + """For backwards compatibility.""" + return cls.get_metrics() diff --git a/src/eval_framework/tasks/eval_config.py b/src/eval_framework/tasks/eval_config.py index 58bc2b4d..99f88a6b 100644 --- a/src/eval_framework/tasks/eval_config.py +++ 
b/src/eval_framework/tasks/eval_config.py @@ -112,7 +112,7 @@ def validate_judge_model_args(cls, value: dict[str, Any]) -> dict[str, Any]: @model_validator(mode="after") def validate_llm_judge_defined(self) -> "EvalConfig": task = get_task(self.task_name) - _, task_metrics = task(num_fewshot=0)._get_type_and_metrics() + task_metrics = task(num_fewshot=0).get_metrics() for metric_class in task_metrics: if issubclass(metric_class, BaseLLMJudgeMetric): assert self.llm_judge_class is not None, "The LLM Judge must be defined for this evaluation task." diff --git a/src/eval_framework/tasks/utils.py b/src/eval_framework/tasks/utils.py index 5752f1e7..d414444d 100644 --- a/src/eval_framework/tasks/utils.py +++ b/src/eval_framework/tasks/utils.py @@ -8,7 +8,7 @@ import threading from collections.abc import Callable from pathlib import Path -from typing import Any, Literal, NamedTuple +from typing import Any, Literal, NamedTuple, overload import dill import numpy as np @@ -22,6 +22,24 @@ RANDOM_SEED = 42 # hacky way to get around circular import redis_warning_printed = False + +class classproperty[T]: + """Descriptor supporting property-like access on classes and instances.""" + + def __init__(self, fget: Callable[[Any], T]) -> None: + self.fget = fget + + @overload + def __get__(self, obj: None, owner: type[Any]) -> T: ... + + @overload + def __get__(self, obj: object, owner: type[Any] | None = None) -> T: ... + + def __get__(self, obj: object | None, owner: type[Any] | None = None) -> T: + cls = owner if owner is not None else type(obj) + return self.fget(cls) + + _pools: dict[tuple[str | None, tuple[str, ...] 
| None], ContainerPoolManager] = {} _pools_lock = threading.Lock() diff --git a/tests/tests_eval_framework/tasks/test_task_style.py b/tests/tests_eval_framework/tasks/test_task_style.py index 5951fbdb..c33c5070 100644 --- a/tests/tests_eval_framework/tasks/test_task_style.py +++ b/tests/tests_eval_framework/tasks/test_task_style.py @@ -506,3 +506,11 @@ def test_metadata_task_style(self) -> None: def test_metadata_metrics_bpb_only(self) -> None: meta = self.task.get_metadata() assert meta["metrics"] == ["BitsPerByte"] + + +def test_instance_properties_are_styler_backed() -> None: + task = _ConcreteMCTask() + + # Check compatibility access points for metadata. + assert task.RESPONSE_TYPE == ResponseType.LOGLIKELIHOODS + assert task.METRICS == task.TASK_STYLER.metrics