Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/eval_framework/response_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, llm: BaseLLM, config: EvalConfig, result_processor: ResultsFi
custom_hf_revision=self.config.hf_revision,
)

self.response_type, _ = self.task._get_type_and_metrics()
self.response_type = self.task.get_response_type()

def _llm_task_param_precedence(self) -> tuple[list[str] | None, int | None]:
"""
Expand Down
41 changes: 30 additions & 11 deletions src/eval_framework/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic import BaseModel, ConfigDict

from eval_framework.shared.types import BaseMetricContext, Completion, Error, RawCompletion
from eval_framework.tasks.utils import raise_errors
from eval_framework.tasks.utils import classproperty, raise_errors
from template_formatting.formatter import Message, Role

if TYPE_CHECKING:
Expand Down Expand Up @@ -91,8 +91,6 @@ class BaseTask[SubjectType](ABC):
DATASET_PATH: str
SAMPLE_SPLIT: str
FEWSHOT_SPLIT: str
RESPONSE_TYPE: ResponseType
METRICS: list[type["BaseMetric"]]
SUBJECTS: list[SubjectType]
HF_REVISION: str | None = None # tag name, or branch name, or commit hash to ensure reproducibility

Expand All @@ -104,6 +102,10 @@ class BaseTask[SubjectType](ABC):
# language by subtopic, or `None` (for tasks not specific to a single language).
LANGUAGE: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None

# RESPONSE_TYPE and METRICS are exposed as classproperties, so you can access them via either
# `TaskClass.*` or `task.*` (or `task.get_metrics()`). This avoids mypy conflicts from re-declaring class vars.
# By default, these values come from TASK_STYLER if set, otherwise from legacy class attributes.

def __init__(self, num_fewshot: int = 0) -> None:
self.num_fewshot = num_fewshot
self.stop_sequences: list[str] | None = None
Expand Down Expand Up @@ -332,14 +334,12 @@ def _get_context(self, item: dict[str, Any]) -> BaseMetricContext | list[BaseMet
return None

def get_metadata(self) -> dict[str, str | list[str]]:
response_type, metrics = self._get_type_and_metrics()

meta: dict[str, str | list[str]] = {
"dataset_path": self.DATASET_PATH,
"sample_split": self.SAMPLE_SPLIT,
"fewshot_split": self.FEWSHOT_SPLIT,
"response_type": response_type.value,
"metrics": [m.NAME for m in metrics],
"response_type": self.get_response_type().value,
"metrics": [m.NAME for m in self.get_metrics()],
"subjects": [str(s) for s in self.SUBJECTS],
}
if hasattr(self, "TASK_STYLER"):
Expand Down Expand Up @@ -420,7 +420,26 @@ def generate_completions(
)
return completion_list

def _get_type_and_metrics(self) -> tuple[ResponseType, list[type["BaseMetric"]]]:
    """Return ``(response_type, metrics)``, preferring ``TASK_STYLER`` when defined."""
    styler = getattr(self, "TASK_STYLER", None) if hasattr(self, "TASK_STYLER") else None
    if styler is not None or hasattr(self, "TASK_STYLER"):
        styler = self.TASK_STYLER
        return styler.response_type, styler.metrics
    # Legacy path: fall back to the class-level attributes.
    return self.RESPONSE_TYPE, self.METRICS
@classmethod
def get_response_type(cls) -> ResponseType:
    """Resolve the task's response type, preferring ``TASK_STYLER`` when present."""
    try:
        styler = cls.TASK_STYLER
    except AttributeError:
        # No styler configured: fall back to the legacy class attribute.
        return cls.RESPONSE_TYPE
    return styler.response_type

@classmethod
def get_metrics(cls) -> list[type["BaseMetric"]]:
    """Resolve the task's metric classes, preferring ``TASK_STYLER`` when present."""
    try:
        styler = cls.TASK_STYLER
    except AttributeError:
        # No styler configured: fall back to the legacy class attribute.
        return cls.METRICS
    return styler.metrics

@classproperty
def RESPONSE_TYPE(cls) -> ResponseType:
    """Backwards-compatible alias that delegates to :meth:`get_response_type`."""
    resolved = cls.get_response_type()
    return resolved

@classproperty
def METRICS(cls) -> list[type["BaseMetric"]]:
    """Backwards-compatible alias that delegates to :meth:`get_metrics`."""
    resolved = cls.get_metrics()
    return resolved
2 changes: 1 addition & 1 deletion src/eval_framework/tasks/eval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def validate_judge_model_args(cls, value: dict[str, Any]) -> dict[str, Any]:
@model_validator(mode="after")
def validate_llm_judge_defined(self) -> "EvalConfig":
task = get_task(self.task_name)
_, task_metrics = task(num_fewshot=0)._get_type_and_metrics()
task_metrics = task(num_fewshot=0).get_metrics()
for metric_class in task_metrics:
if issubclass(metric_class, BaseLLMJudgeMetric):
assert self.llm_judge_class is not None, "The LLM Judge must be defined for this evaluation task."
Expand Down
20 changes: 19 additions & 1 deletion src/eval_framework/tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import threading
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal, NamedTuple
from typing import Any, Literal, NamedTuple, overload

import dill
import numpy as np
Expand All @@ -22,6 +22,24 @@
RANDOM_SEED = 42 # hacky way to get around circular import
redis_warning_printed = False


class classproperty[T]:
"""Descriptor supporting property-like access on classes and instances."""

def __init__(self, fget: Callable[[Any], T]) -> None:
self.fget = fget

@overload
def __get__(self, obj: None, owner: type[Any]) -> T: ...

@overload
def __get__(self, obj: object, owner: type[Any] | None = None) -> T: ...

def __get__(self, obj: object | None, owner: type[Any] | None = None) -> T:
cls = owner if owner is not None else type(obj)
return self.fget(cls)


_pools: dict[tuple[str | None, tuple[str, ...] | None], ContainerPoolManager] = {}
_pools_lock = threading.Lock()

Expand Down
8 changes: 8 additions & 0 deletions tests/tests_eval_framework/tasks/test_task_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,11 @@ def test_metadata_task_style(self) -> None:
def test_metadata_metrics_bpb_only(self) -> None:
meta = self.task.get_metadata()
assert meta["metrics"] == ["BitsPerByte"]


def test_instance_properties_are_styler_backed() -> None:
    instance = _ConcreteMCTask()

    # Legacy class-attribute access points must keep resolving via the styler.
    assert instance.RESPONSE_TYPE == ResponseType.LOGLIKELIHOODS
    assert instance.METRICS == instance.TASK_STYLER.metrics
Loading