Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/_pull_requests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ jobs:

steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v7.0.0

- name: Set up Miniconda
uses: conda-incubator/setup-miniconda@v3
uses: conda-incubator/setup-miniconda@v4.0.1
with:
activate-environment: graphrag-eval
environment-file: environment.yml
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v7.0.0

- name: Set up Miniconda
uses: conda-incubator/setup-miniconda@v3
uses: conda-incubator/setup-miniconda@v4.0.1
with:
activate-environment: graphrag-eval
environment-file: environment.yml
Expand Down
6 changes: 6 additions & 0 deletions docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
`run_evaluation()` and `compute_aggregates()` are configured using a YAML file whose path is passed as the optional parameter `config_file_path`. Example call:

```python
from graphrag_eval import run_evaluation


evaluation_results = await run_evaluation(
reference_dataset,
target_dataset,
Expand Down Expand Up @@ -43,6 +46,7 @@ The configuration has the following structure:
- `instructions`: (`str`) instructions for the evaluation
- `outputs`: (`map[str,str]`) output variable names and descriptions
- `answer_correctness`
- `enabled` (`bool`, default: `True`) - if `False`, then [answer correctness metrics](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md) won't be calculated.
- `prompt`: (`str`, default [here](https://github.com/Ontotext-AD/graphrag-eval/blob/main/graphrag_eval/prompts/template.md))
Template for instructions to the LLM on how to compute the answer correctness.
Must contain placeholders `{question}`, `{reference_answer}` and `{actual_answer}` and no others.
Expand All @@ -52,6 +56,8 @@ The configuration has the following structure:
# Output format
<v1><tab><v2><tab><v3><tab><v4>
```
- `answer_relevance`
- `enabled` (`bool`, default: `True`) - if `False`, then [answer relevance metric](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md) won't be calculated.

## Example configuration file with LLM configuration

Expand Down
3 changes: 3 additions & 0 deletions docs/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ See [§ Example configuration file with custom evaluation](https://github.com/On
### Example call to evaluate using custom metrics

```python
from graphrag_eval import run_evaluation, compute_aggregates


evaluation_results = await run_evaluation(
reference_qa_dataset,
chat_responses,
Expand Down
2 changes: 1 addition & 1 deletion docs/output.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The output is a list of objects, one for each reference item. Each output object
- `answer_recall`: (optional) `answer_matching_claims_count / answer_reference_claims_count`
- `answer_precision`: (optional) `answer_matching_claims_count / answer_actual_claims_count`
- `answer_correctness_reason`: (optional) LLM reasoning in extracting and matching claims from `reference_answer` and `actual_answer`
- `answer_eval_error`: (optional) error message if answer evaluation failed
- `answer_correctness_error`: (optional) error message if answer correctness evaluation failed
- `answer_f1`: (optional) Harmonic mean of `answer_recall` and `answer_precision`
- `answer_relevance`: (optional `float` in [0, 1]) answer relevance score
- `answer_relevance_error`: (optional) error message if answer relevance evaluation failed
Expand Down
8 changes: 4 additions & 4 deletions graphrag_eval/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
import yaml
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from statistics import mean, median
from typing import Any, Collection, Iterable

from . import evaluation
import yaml

from . import evaluation

METRICS = [
"answer_recall",
Expand Down Expand Up @@ -155,7 +155,7 @@ def compute_micro_stats(
) -> dict:
if custom_metrics is None:
custom_metrics = []

values = number_of_samples_per_template_by_status.values()
micro_summary = defaultdict(dict, {
"number_of_error_samples": sum(v["error"] for v in values),
Expand Down Expand Up @@ -197,7 +197,7 @@ def compute_macro_stats(
) -> dict:
if custom_metrics is None:
custom_metrics = []

macro_summary = defaultdict(dict)
for metric in METRICS + custom_metrics:
means = [
Expand Down
67 changes: 52 additions & 15 deletions graphrag_eval/answer_correctness.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,62 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Self, TYPE_CHECKING

from pydantic import BaseModel, Field

from graphrag_eval.util import compute_f1
from .evaluator import Evaluator

if TYPE_CHECKING:
from ragas.llms.base import InstructorBaseRagasLLM


def load_default_prompt() -> str:
with open(Path(__file__).parent / "prompts" / "template.md", "r", encoding="utf-8") as f:
with open(
Path(__file__).parent / "prompts" / "template.md",
encoding="utf-8"
) as f:
return f.read()


class AnswerCorrectnessConfig(BaseModel):
Comment thread
nelly-hateva marked this conversation as resolved.
enabled: bool = Field(default=True)
prompt: str = Field(default_factory=load_default_prompt)


class InvalidPromptException(Exception):
def __init__(self, message="The prompt template is invalid and cannot be formatted."):
def __init__(
self,
message="The prompt template is invalid and cannot be "
"formatted."
):
self.message = message
super().__init__(self.message)


class AnswerCorrectnessEvaluator:
def __init__(
self,
llm: "InstructorBaseRagasLLM",
ragas_llm: InstructorBaseRagasLLM,
config: AnswerCorrectnessConfig | None = None,
):
self.config = config or AnswerCorrectnessConfig()
self.__validate_prompt_template(self.config.prompt)
self.prompt_template = self.config.prompt
self.llm = llm
self.ragas_llm = ragas_llm

@classmethod
def from_config(
cls,
ragas_llm: InstructorBaseRagasLLM | None,
config: AnswerCorrectnessConfig | None
) -> Self | None:
if ragas_llm is None:
return None
if config is None or not config.enabled:
return None
return cls(ragas_llm=ragas_llm, config=config)

@staticmethod
def __validate_prompt_template(prompt_template: str):
Expand All @@ -48,17 +75,21 @@ def __validate_prompt_template(prompt_template: str):

async def _agenerate(self, prompt):
"""Wrapper method for easier testing"""
return (await self.llm.agenerate(prompt, None)).choices[0].message.content
return (await self.ragas_llm.agenerate(prompt, None)).choices[0].message.content

async def evaluate_answer(
self,
question: str,
reference_answer: str,
actual_answer: str
) -> tuple[int, int, int, str]:
if any(not s.strip() for s in [question, reference_answer, actual_answer]):
raise ValueError("The question of the reference or the actual answer is a blank "
"string!")
if any(
not s.strip() for s in [question, reference_answer, actual_answer]
):
raise ValueError(
"The question of the reference or the actual answer is a blank "
"string!"
)
prompt = self.prompt_template.format(
question=question,
reference_answer=reference_answer,
Expand All @@ -67,12 +98,14 @@ async def evaluate_answer(
response_str = await self._agenerate(prompt)
return self.extract_response_values(response_str)

async def get_correctness_dict(
async def evaluate(
self,
reference: dict,
actual: dict,
):
result = {"reference_answer": reference["reference_answer"]}
reference: dict[str, Any],
actual: dict[str, Any]
) -> dict[str, Any]:
if "actual_answer" not in actual or "reference_answer" not in reference:
return {}
result = {}
try:
num_ref_claims, num_actual_claims, num_matching_claims, reason = \
await self.evaluate_answer(
Expand All @@ -96,7 +129,7 @@ async def get_correctness_dict(
if f1 is not None:
result["answer_f1"] = f1
except Exception as exc:
result["answer_eval_error"] = str(exc)
result["answer_correctness_error"] = str(exc)
return result

@staticmethod
Expand Down Expand Up @@ -134,6 +167,10 @@ def extract_response_values(
n_matching > n_actual
]):
raise ValueError(
f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}"
"Invalid claims counts combination: "
f"{n_ref}\t{n_actual}\t{n_matching}"
)
return n_ref, n_actual, n_matching, vals[3]


_: Evaluator = AnswerCorrectnessEvaluator
Comment thread
nelly-hateva marked this conversation as resolved.
58 changes: 45 additions & 13 deletions graphrag_eval/answer_relevance.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,53 @@
from ragas.embeddings.base import BaseRagasEmbedding
from ragas.llms.base import InstructorBaseRagasLLM
from ragas.metrics.collections import AnswerRelevancy
from __future__ import annotations

from graphrag_eval.util import singleton
from typing import Any, Self, TYPE_CHECKING

from pydantic import BaseModel, Field

@singleton
class Evaluator:
def __init__(self, ragas_llm: InstructorBaseRagasLLM, ragas_embedder: BaseRagasEmbedding):
from .evaluator import Evaluator

if TYPE_CHECKING:
from ragas.llms.base import InstructorBaseRagasLLM
from ragas.embeddings.base import BaseRagasEmbeddings, BaseRagasEmbedding


class AnswerRelevanceConfig(BaseModel):
enabled: bool = Field(default=True)


class AnswerRelevanceEvaluator:
def __init__(
self,
ragas_llm: InstructorBaseRagasLLM,
ragas_embedder: BaseRagasEmbeddings | BaseRagasEmbedding
):
from ragas.metrics.collections import AnswerRelevancy
self.scorer = AnswerRelevancy(llm=ragas_llm, embeddings=ragas_embedder)

async def get_relevance_dict(
@classmethod
def from_config(
cls,
ragas_llm: InstructorBaseRagasLLM | None,
ragas_embedder: BaseRagasEmbeddings | BaseRagasEmbedding | None,
config: AnswerRelevanceConfig | None
) -> Self | None:
if ragas_llm is None or ragas_embedder is None:
return None
if config is None or not config.enabled:
return None
return cls(ragas_llm=ragas_llm, ragas_embedder=ragas_embedder)

async def evaluate(
self,
question_text: str,
actual_answer: str,
) -> dict:
reference: dict[str, Any],
actual: dict[str, Any]
) -> dict[str, Any]:
if "actual_answer" not in actual:
return {}
try:
result = await self.scorer.ascore(
user_input=question_text,
response=actual_answer
user_input=reference["question_text"],
response=actual["actual_answer"]
)
return {
"answer_relevance": result.value
Expand All @@ -27,3 +56,6 @@ async def get_relevance_dict(
return {
"answer_relevance_error": str(e)
}


_: Evaluator = AnswerRelevanceEvaluator
Loading
Loading